Pettingzoo support (#494)
Co-authored-by: Rodrigo de Lazcano <r.l.p.v96@gmail.com>
Co-authored-by: J K Terry <justinkterry@gmail.com>
This commit is contained in:
parent d85bc19269
commit c7e2e56fac

setup.py (4 changes)
@@ -55,6 +55,7 @@ setup(
        "torch>=1.4.0",
        "numba>=0.51.0",
        "h5py>=2.10.0",  # to match tensorflow's minimal requirements
        "pettingzoo>=1.12,<=1.13",
    ],
    extras_require={
        "dev": [
@@ -74,6 +75,9 @@ setup(
            "pydocstyle",
            "doc8",
            "scipy",
            "pillow",
            "pygame>=2.1.0",  # pettingzoo test cases pistonball
            "pymunk>=6.2.1",  # pettingzoo test cases pistonball
        ],
        "atari": ["atari_py", "opencv-python"],
        "mujoco": ["mujoco_py"],
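The new `pettingzoo>=1.12,<=1.13` entry sits in install_requires, so a plain `pip install -e .` of this revision pulls PettingZoo in automatically. A minimal sketch for sanity-checking the resolved version afterwards (the check itself is illustrative and not part of the commit):

# Hedged sketch: confirm the installed pettingzoo version is inside the pinned range.
import pettingzoo

major, minor = (int(x) for x in pettingzoo.__version__.split(".")[:2])
assert (1, 12) <= (major, minor) <= (1, 13), pettingzoo.__version__
print("pettingzoo", pettingzoo.__version__, "is within the supported 1.12-1.13 range")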
@@ -1,87 +0,0 @@
|
||||
import os
|
||||
import pprint
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
from tic_tac_toe import get_agents, get_parser, train_agent, watch
|
||||
from tic_tac_toe_env import TicTacToeEnv
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.data import Collector
|
||||
from tianshou.env import DummyVectorEnv
|
||||
from tianshou.policy import RandomPolicy
|
||||
from tianshou.utils import TensorboardLogger
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = get_parser()
|
||||
parser.add_argument('--self_play_round', type=int, default=20)
|
||||
args = parser.parse_known_args()[0]
|
||||
return args
|
||||
|
||||
|
||||
def gomoku(args=get_args()):
|
||||
Collector._default_rew_metric = lambda x: x[args.agent_id - 1]
|
||||
if args.watch:
|
||||
watch(args)
|
||||
return
|
||||
|
||||
policy, optim = get_agents(args)
|
||||
agent_learn = policy.policies[args.agent_id - 1]
|
||||
agent_opponent = policy.policies[2 - args.agent_id]
|
||||
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, 'Gomoku', 'dqn')
|
||||
writer = SummaryWriter(log_path)
|
||||
args.logger = TensorboardLogger(writer)
|
||||
|
||||
opponent_pool = [agent_opponent]
|
||||
|
||||
def env_func():
|
||||
return TicTacToeEnv(args.board_size, args.win_size)
|
||||
|
||||
test_envs = DummyVectorEnv([env_func for _ in range(args.test_num)])
|
||||
for round in range(args.self_play_round):
|
||||
rews = []
|
||||
agent_learn.set_eps(0.0)
|
||||
# compute the reward over previous learner
|
||||
for opponent in opponent_pool:
|
||||
policy.replace_policy(opponent, 3 - args.agent_id)
|
||||
test_collector = Collector(policy, test_envs)
|
||||
results = test_collector.collect(n_episode=100)
|
||||
rews.append(results['rews'].mean())
|
||||
rews = np.array(rews)
|
||||
# weight opponent by their difficulty level
|
||||
rews = np.exp(-rews * 10.0)
|
||||
rews /= np.sum(rews)
|
||||
total_epoch = args.epoch
|
||||
args.epoch = 1
|
||||
for epoch in range(total_epoch):
|
||||
# sample one opponent
|
||||
opp_id = np.random.choice(len(opponent_pool), size=1, p=rews)
|
||||
print(f'selection probability {rews.tolist()}')
|
||||
print(f'selected opponent {opp_id}')
|
||||
opponent = opponent_pool[opp_id.item(0)]
|
||||
agent = RandomPolicy()
|
||||
# previous learner can only be used for forward
|
||||
agent.forward = opponent.forward
|
||||
args.model_save_path = os.path.join(
|
||||
args.logdir, 'Gomoku', 'dqn', f'policy_round_{round}_epoch_{epoch}.pth'
|
||||
)
|
||||
result, agent_learn = train_agent(
|
||||
args, agent_learn=agent_learn, agent_opponent=agent, optim=optim
|
||||
)
|
||||
print(f'round_{round}_epoch_{epoch}')
|
||||
pprint.pprint(result)
|
||||
learnt_agent = deepcopy(agent_learn)
|
||||
learnt_agent.set_eps(0.0)
|
||||
opponent_pool.append(learnt_agent)
|
||||
args.epoch = total_epoch
|
||||
if __name__ == '__main__':
|
||||
# Let's watch its performance!
|
||||
opponent = opponent_pool[-2]
|
||||
watch(args, agent_learn, opponent)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
gomoku(get_args())
|
||||
@@ -1,148 +0,0 @@
|
||||
from functools import partial
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
||||
from tianshou.env import MultiAgentEnv
|
||||
|
||||
|
||||
class TicTacToeEnv(MultiAgentEnv):
|
||||
"""This is a simple implementation of the Tic-Tac-Toe game, where two
|
||||
agents play against each other.
|
||||
|
||||
The implementation is intended to show how to wrap an environment to
|
||||
satisfy the interface of :class:`~tianshou.env.MultiAgentEnv`.
|
||||
|
||||
:param size: the size of the board (square board)
|
||||
:param win_size: how many units in a row is considered to win
|
||||
"""
|
||||
|
||||
def __init__(self, size: int = 3, win_size: int = 3):
|
||||
super().__init__()
|
||||
assert size > 0, f'board size should be positive, but got {size}'
|
||||
self.size = size
|
||||
assert win_size > 0, f'win-size should be positive, but got {win_size}'
|
||||
self.win_size = win_size
|
||||
assert win_size <= size, f'win-size {win_size} should not ' \
|
||||
f'be larger than board size {size}'
|
||||
self.convolve_kernel = np.ones(win_size)
|
||||
self.observation_space = gym.spaces.Box(
|
||||
low=-1.0, high=1.0, shape=(size, size), dtype=np.float32
|
||||
)
|
||||
self.action_space = gym.spaces.Discrete(size * size)
|
||||
self.current_board = None
|
||||
self.current_agent = None
|
||||
self._last_move = None
|
||||
self.step_num = None
|
||||
|
||||
def reset(self) -> dict:
|
||||
self.current_board = np.zeros((self.size, self.size), dtype=np.int32)
|
||||
self.current_agent = 1
|
||||
self._last_move = (-1, -1)
|
||||
self.step_num = 0
|
||||
return {
|
||||
'agent_id': self.current_agent,
|
||||
'obs': np.array(self.current_board),
|
||||
'mask': self.current_board.flatten() == 0
|
||||
}
|
||||
|
||||
def step(self, action: [int,
|
||||
np.ndarray]) -> Tuple[dict, np.ndarray, np.ndarray, dict]:
|
||||
if self.current_agent is None:
|
||||
raise ValueError("calling step() of unreset environment is prohibited!")
|
||||
assert 0 <= action < self.size * self.size
|
||||
assert self.current_board.item(action) == 0
|
||||
_current_agent = self.current_agent
|
||||
self._move(action)
|
||||
mask = self.current_board.flatten() == 0
|
||||
is_win, is_opponent_win = False, False
|
||||
is_win = self._test_win()
|
||||
# the game is over when one wins or there is only one empty place
|
||||
done = is_win
|
||||
if sum(mask) == 1:
|
||||
done = True
|
||||
self._move(np.where(mask)[0][0])
|
||||
is_opponent_win = self._test_win()
|
||||
if is_win:
|
||||
reward = 1
|
||||
elif is_opponent_win:
|
||||
reward = -1
|
||||
else:
|
||||
reward = 0
|
||||
obs = {
|
||||
'agent_id': self.current_agent,
|
||||
'obs': np.array(self.current_board),
|
||||
'mask': mask
|
||||
}
|
||||
rew_agent_1 = reward if _current_agent == 1 else (-reward)
|
||||
rew_agent_2 = reward if _current_agent == 2 else (-reward)
|
||||
vec_rew = np.array([rew_agent_1, rew_agent_2], dtype=np.float32)
|
||||
if done:
|
||||
self.current_agent = None
|
||||
return obs, vec_rew, np.array(done), {}
|
||||
|
||||
def _move(self, action):
|
||||
row, col = action // self.size, action % self.size
|
||||
if self.current_agent == 1:
|
||||
self.current_board[row, col] = 1
|
||||
else:
|
||||
self.current_board[row, col] = -1
|
||||
self.current_agent = 3 - self.current_agent
|
||||
self._last_move = (row, col)
|
||||
self.step_num += 1
|
||||
|
||||
def _test_win(self):
|
||||
"""test if someone wins by checking the situation around last move"""
|
||||
row, col = self._last_move
|
||||
rboard = self.current_board[row, :]
|
||||
cboard = self.current_board[:, col]
|
||||
current = self.current_board[row, col]
|
||||
rightup = [
|
||||
self.current_board[row - i, col + i] for i in range(1, self.size - col)
|
||||
if row - i >= 0
|
||||
]
|
||||
leftdown = [
|
||||
self.current_board[row + i, col - i] for i in range(1, col + 1)
|
||||
if row + i < self.size
|
||||
]
|
||||
rdiag = np.array(leftdown[::-1] + [current] + rightup)
|
||||
rightdown = [
|
||||
self.current_board[row + i, col + i] for i in range(1, self.size - col)
|
||||
if row + i < self.size
|
||||
]
|
||||
leftup = [
|
||||
self.current_board[row - i, col - i] for i in range(1, col + 1)
|
||||
if row - i >= 0
|
||||
]
|
||||
diag = np.array(leftup[::-1] + [current] + rightdown)
|
||||
results = [
|
||||
np.convolve(k, self.convolve_kernel, mode='valid')
|
||||
for k in (rboard, cboard, rdiag, diag)
|
||||
]
|
||||
return any([(np.abs(x) == self.win_size).any() for x in results])
|
||||
|
||||
def seed(self, seed: Optional[int] = None) -> int:
|
||||
pass
|
||||
|
||||
def render(self, **kwargs) -> None:
|
||||
print(f'board (step {self.step_num}):')
|
||||
pad = '==='
|
||||
top = pad + '=' * (2 * self.size - 1) + pad
|
||||
print(top)
|
||||
|
||||
def f(i, data):
|
||||
j, number = data
|
||||
last_move = i == self._last_move[0] and j == self._last_move[1]
|
||||
if number == 1:
|
||||
return 'X' if last_move else 'x'
|
||||
if number == -1:
|
||||
return 'O' if last_move else 'o'
|
||||
return '_'
|
||||
|
||||
for i, row in enumerate(self.current_board):
|
||||
print(pad + ' '.join(map(partial(f, i), enumerate(row))) + pad)
|
||||
print(top)
|
||||
|
||||
def close(self) -> None:
|
||||
pass
|
||||
test/pettingzoo/pistonball.py (new file, 185 lines)
@@ -0,0 +1,185 @@
import argparse
|
||||
import os
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import pettingzoo.butterfly.pistonball_v4 as pistonball_v4
|
||||
import torch
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.data import Collector, VectorReplayBuffer
|
||||
from tianshou.env import DummyVectorEnv
|
||||
from tianshou.env.pettingzoo_env import PettingZooEnv
|
||||
from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.utils import TensorboardLogger
|
||||
from tianshou.utils.net.common import Net
|
||||
|
||||
|
||||
def get_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--seed', type=int, default=1626)
|
||||
parser.add_argument('--eps-test', type=float, default=0.05)
|
||||
parser.add_argument('--eps-train', type=float, default=0.1)
|
||||
parser.add_argument('--buffer-size', type=int, default=2000)
|
||||
parser.add_argument('--lr', type=float, default=1e-3)
|
||||
parser.add_argument(
|
||||
'--gamma', type=float, default=0.9, help='a smaller gamma favors earlier win'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--n-pistons',
|
||||
type=int,
|
||||
default=3,
|
||||
help='Number of pistons(agents) in the env'
|
||||
)
|
||||
parser.add_argument('--n-step', type=int, default=100)
|
||||
parser.add_argument('--target-update-freq', type=int, default=320)
|
||||
parser.add_argument('--epoch', type=int, default=3)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=500)
|
||||
parser.add_argument('--step-per-collect', type=int, default=10)
|
||||
parser.add_argument('--update-per-step', type=float, default=0.1)
|
||||
parser.add_argument('--batch-size', type=int, default=100)
|
||||
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
|
||||
parser.add_argument('--training-num', type=int, default=10)
|
||||
parser.add_argument('--test-num', type=int, default=100)
|
||||
parser.add_argument('--logdir', type=str, default='log')
|
||||
parser.add_argument('--render', type=float, default=0.0)
|
||||
|
||||
parser.add_argument(
|
||||
'--watch',
|
||||
default=False,
|
||||
action='store_true',
|
||||
help='no training, '
|
||||
'watch the play of pre-trained models'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def get_args() -> argparse.Namespace:
|
||||
parser = get_parser()
|
||||
return parser.parse_known_args()[0]
|
||||
|
||||
|
||||
def get_env(args: argparse.Namespace = get_args()):
|
||||
return PettingZooEnv(pistonball_v4.env(continuous=False, n_pistons=args.n_pistons))
|
||||
|
||||
|
||||
def get_agents(
|
||||
args: argparse.Namespace = get_args(),
|
||||
agents: Optional[List[BasePolicy]] = None,
|
||||
optims: Optional[List[torch.optim.Optimizer]] = None,
|
||||
) -> Tuple[BasePolicy, List[torch.optim.Optimizer], List]:
|
||||
env = get_env()
|
||||
observation_space = env.observation_space['observation'] if isinstance(
|
||||
env.observation_space, gym.spaces.Dict
|
||||
) else env.observation_space
|
||||
args.state_shape = observation_space.shape or observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
if agents is None:
|
||||
agents = []
|
||||
optims = []
|
||||
for _ in range(args.n_pistons):
|
||||
# model
|
||||
net = Net(
|
||||
args.state_shape,
|
||||
args.action_shape,
|
||||
hidden_sizes=args.hidden_sizes,
|
||||
device=args.device
|
||||
).to(args.device)
|
||||
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
|
||||
agent = DQNPolicy(
|
||||
net,
|
||||
optim,
|
||||
args.gamma,
|
||||
args.n_step,
|
||||
target_update_freq=args.target_update_freq
|
||||
)
|
||||
agents.append(agent)
|
||||
optims.append(optim)
|
||||
|
||||
policy = MultiAgentPolicyManager(agents, env)
|
||||
return policy, optims, env.agents
|
||||
|
||||
|
||||
def train_agent(
|
||||
args: argparse.Namespace = get_args(),
|
||||
agents: Optional[List[BasePolicy]] = None,
|
||||
optims: Optional[List[torch.optim.Optimizer]] = None,
|
||||
) -> Tuple[dict, BasePolicy]:
|
||||
train_envs = DummyVectorEnv([get_env for _ in range(args.training_num)])
|
||||
test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)])
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
train_envs.seed(args.seed)
|
||||
test_envs.seed(args.seed)
|
||||
|
||||
policy, optim, agents = get_agents(args, agents=agents, optims=optims)
|
||||
|
||||
# collector
|
||||
train_collector = Collector(
|
||||
policy,
|
||||
train_envs,
|
||||
VectorReplayBuffer(args.buffer_size, len(train_envs)),
|
||||
exploration_noise=True
|
||||
)
|
||||
test_collector = Collector(policy, test_envs, exploration_noise=True)
|
||||
train_collector.collect(n_step=args.batch_size * args.training_num)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, 'pistonball', 'dqn')
|
||||
writer = SummaryWriter(log_path)
|
||||
writer.add_text("args", str(args))
|
||||
logger = TensorboardLogger(writer)
|
||||
|
||||
def save_fn(policy):
|
||||
pass
|
||||
|
||||
def stop_fn(mean_rewards):
|
||||
return False
|
||||
|
||||
def train_fn(epoch, env_step):
|
||||
[agent.set_eps(args.eps_train) for agent in policy.policies.values()]
|
||||
|
||||
def test_fn(epoch, env_step):
|
||||
[agent.set_eps(args.eps_test) for agent in policy.policies.values()]
|
||||
|
||||
def reward_metric(rews):
|
||||
return rews[:, 0]
|
||||
|
||||
# trainer
|
||||
result = offpolicy_trainer(
|
||||
policy,
|
||||
train_collector,
|
||||
test_collector,
|
||||
args.epoch,
|
||||
args.step_per_epoch,
|
||||
args.step_per_collect,
|
||||
args.test_num,
|
||||
args.batch_size,
|
||||
train_fn=train_fn,
|
||||
test_fn=test_fn,
|
||||
stop_fn=stop_fn,
|
||||
save_fn=save_fn,
|
||||
update_per_step=args.update_per_step,
|
||||
logger=logger,
|
||||
test_in_train=False,
|
||||
reward_metric=reward_metric
|
||||
)
|
||||
|
||||
return result, policy
|
||||
|
||||
|
||||
def watch(
|
||||
args: argparse.Namespace = get_args(), policy: Optional[BasePolicy] = None
|
||||
) -> None:
|
||||
env = get_env()
|
||||
policy.eval()
|
||||
[agent.set_eps(args.eps_test) for agent in policy.policies.values()]
|
||||
collector = Collector(policy, env, exploration_noise=True)
|
||||
result = collector.collect(n_episode=1, render=args.render)
|
||||
rews, lens = result["rews"], result["lens"]
|
||||
print(f"Final reward: {rews[:, 0].mean()}, length: {lens.mean()}")
|
||||
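The pistonball example above wires one DQN policy per piston into a MultiAgentPolicyManager over the wrapped pistonball_v4 AEC environment, then trains with offpolicy_trainer. A minimal sketch of reusing it from another script, assuming test/pettingzoo is on the import path (the flag values are illustrative only):

# Hedged sketch: drive the pistonball example programmatically.
import sys

from pistonball import get_args, train_agent, watch

sys.argv = ["pistonball", "--epoch", "1", "--n-pistons", "3"]  # keep the run short
args = get_args()
result, policy = train_agent(args)   # returns the trainer stats and the joint policy
print(result["best_reward"])
watch(args, policy)                  # replay one episode with eps-greedy evaluation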
test/pettingzoo/pistonball_continuous.py (new file, 276 lines)
@@ -0,0 +1,276 @@
import argparse
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import pettingzoo.butterfly.pistonball_v4 as pistonball_v4
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.distributions import Independent, Normal
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.data import Collector, VectorReplayBuffer
|
||||
from tianshou.env import DummyVectorEnv
|
||||
from tianshou.env.pettingzoo_env import PettingZooEnv
|
||||
from tianshou.policy import BasePolicy, MultiAgentPolicyManager, PPOPolicy
|
||||
from tianshou.trainer import onpolicy_trainer
|
||||
from tianshou.utils import TensorboardLogger
|
||||
from tianshou.utils.net.continuous import ActorProb, Critic
|
||||
|
||||
|
||||
class DQN(nn.Module):
|
||||
"""Reference: Human-level control through deep reinforcement learning.
|
||||
|
||||
For advanced usage (how to customize the network), please refer to
|
||||
:ref:`build_the_network`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
c: int,
|
||||
h: int,
|
||||
w: int,
|
||||
device: Union[str, int, torch.device] = "cpu",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.c = c
|
||||
self.h = h
|
||||
self.w = w
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(c, 32, kernel_size=8, stride=4), nn.ReLU(inplace=True),
|
||||
nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(inplace=True),
|
||||
nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(inplace=True),
|
||||
nn.Flatten()
|
||||
)
|
||||
with torch.no_grad():
|
||||
self.output_dim = np.prod(self.net(torch.zeros(1, c, h, w)).shape[1:])
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Union[np.ndarray, torch.Tensor],
|
||||
state: Optional[Any] = None,
|
||||
info: Dict[str, Any] = {},
|
||||
) -> Tuple[torch.Tensor, Any]:
|
||||
r"""Mapping: x -> Q(x, \*)."""
|
||||
x = torch.as_tensor(x, device=self.device, dtype=torch.float32)
|
||||
return self.net(x.reshape(-1, self.c, self.w, self.h)), state
|
||||
|
||||
|
||||
def get_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--seed', type=int, default=1626)
|
||||
parser.add_argument('--eps-test', type=float, default=0.05)
|
||||
parser.add_argument('--eps-train', type=float, default=0.1)
|
||||
parser.add_argument('--buffer-size', type=int, default=2000)
|
||||
parser.add_argument('--lr', type=float, default=1e-3)
|
||||
parser.add_argument(
|
||||
'--gamma', type=float, default=0.9, help='a smaller gamma favors earlier win'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--n-pistons',
|
||||
type=int,
|
||||
default=3,
|
||||
help='Number of pistons(agents) in the env'
|
||||
)
|
||||
parser.add_argument('--n-step', type=int, default=100)
|
||||
parser.add_argument('--target-update-freq', type=int, default=320)
|
||||
parser.add_argument('--epoch', type=int, default=5)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=500)
|
||||
parser.add_argument('--step-per-collect', type=int, default=10)
|
||||
parser.add_argument('--episode-per-collect', type=int, default=16)
|
||||
parser.add_argument('--repeat-per-collect', type=int, default=2)
|
||||
parser.add_argument('--update-per-step', type=float, default=0.1)
|
||||
parser.add_argument('--batch-size', type=int, default=1000)
|
||||
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
|
||||
parser.add_argument('--training-num', type=int, default=1000)
|
||||
parser.add_argument('--test-num', type=int, default=100)
|
||||
parser.add_argument('--logdir', type=str, default='log')
|
||||
|
||||
parser.add_argument(
|
||||
'--watch',
|
||||
default=False,
|
||||
action='store_true',
|
||||
help='no training, '
|
||||
'watch the play of pre-trained models'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
|
||||
)
|
||||
# ppo special
|
||||
parser.add_argument('--vf-coef', type=float, default=0.25)
|
||||
parser.add_argument('--ent-coef', type=float, default=0.0)
|
||||
parser.add_argument('--eps-clip', type=float, default=0.2)
|
||||
parser.add_argument('--max-grad-norm', type=float, default=0.5)
|
||||
parser.add_argument('--gae-lambda', type=float, default=0.95)
|
||||
parser.add_argument('--rew-norm', type=int, default=1)
|
||||
parser.add_argument('--dual-clip', type=float, default=None)
|
||||
parser.add_argument('--value-clip', type=int, default=1)
|
||||
parser.add_argument('--norm-adv', type=int, default=1)
|
||||
parser.add_argument('--recompute-adv', type=int, default=0)
|
||||
parser.add_argument('--resume', action="store_true")
|
||||
parser.add_argument("--save-interval", type=int, default=4)
|
||||
parser.add_argument('--render', type=float, default=0.0)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def get_args() -> argparse.Namespace:
|
||||
parser = get_parser()
|
||||
return parser.parse_known_args()[0]
|
||||
|
||||
|
||||
def get_env(args: argparse.Namespace = get_args()):
|
||||
return PettingZooEnv(pistonball_v4.env(continuous=True, n_pistons=args.n_pistons))
|
||||
|
||||
|
||||
def get_agents(
|
||||
args: argparse.Namespace = get_args(),
|
||||
agents: Optional[List[BasePolicy]] = None,
|
||||
optims: Optional[List[torch.optim.Optimizer]] = None,
|
||||
) -> Tuple[BasePolicy, List[torch.optim.Optimizer], List]:
|
||||
env = get_env()
|
||||
observation_space = env.observation_space['observation'] if isinstance(
|
||||
env.observation_space, gym.spaces.Dict
|
||||
) else env.observation_space
|
||||
args.state_shape = observation_space.shape or observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
args.max_action = env.action_space.high[0]
|
||||
|
||||
if agents is None:
|
||||
agents = []
|
||||
optims = []
|
||||
for _ in range(args.n_pistons):
|
||||
# model
|
||||
net = DQN(
|
||||
observation_space.shape[2],
|
||||
observation_space.shape[1],
|
||||
observation_space.shape[0],
|
||||
device=args.device
|
||||
).to(args.device)
|
||||
|
||||
actor = ActorProb(
|
||||
net, args.action_shape, max_action=args.max_action, device=args.device
|
||||
).to(args.device)
|
||||
net2 = DQN(
|
||||
observation_space.shape[2],
|
||||
observation_space.shape[1],
|
||||
observation_space.shape[0],
|
||||
device=args.device
|
||||
).to(args.device)
|
||||
critic = Critic(net2, device=args.device).to(args.device)
|
||||
for m in set(actor.modules()).union(critic.modules()):
|
||||
if isinstance(m, torch.nn.Linear):
|
||||
torch.nn.init.orthogonal_(m.weight)
|
||||
torch.nn.init.zeros_(m.bias)
|
||||
optim = torch.optim.Adam(
|
||||
set(actor.parameters()).union(critic.parameters()), lr=args.lr
|
||||
)
|
||||
|
||||
def dist(*logits):
|
||||
return Independent(Normal(*logits), 1)
|
||||
|
||||
agent = PPOPolicy(
|
||||
actor,
|
||||
critic,
|
||||
optim,
|
||||
dist,
|
||||
discount_factor=args.gamma,
|
||||
max_grad_norm=args.max_grad_norm,
|
||||
eps_clip=args.eps_clip,
|
||||
vf_coef=args.vf_coef,
|
||||
ent_coef=args.ent_coef,
|
||||
reward_normalization=args.rew_norm,
|
||||
advantage_normalization=args.norm_adv,
|
||||
recompute_advantage=args.recompute_adv,
|
||||
# dual_clip=args.dual_clip,
|
||||
# dual clip cause monotonically increasing log_std :)
|
||||
value_clip=args.value_clip,
|
||||
gae_lambda=args.gae_lambda,
|
||||
action_space=env.action_space
|
||||
)
|
||||
|
||||
agents.append(agent)
|
||||
optims.append(optim)
|
||||
|
||||
policy = MultiAgentPolicyManager(
|
||||
agents, env, action_scaling=True, action_bound_method='clip'
|
||||
)
|
||||
return policy, optims, env.agents
|
||||
|
||||
|
||||
def train_agent(
|
||||
args: argparse.Namespace = get_args(),
|
||||
agents: Optional[List[BasePolicy]] = None,
|
||||
optims: Optional[List[torch.optim.Optimizer]] = None,
|
||||
) -> Tuple[dict, BasePolicy]:
|
||||
train_envs = DummyVectorEnv([get_env for _ in range(args.training_num)])
|
||||
test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)])
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
train_envs.seed(args.seed)
|
||||
test_envs.seed(args.seed)
|
||||
|
||||
policy, optim, agents = get_agents(args, agents=agents, optims=optims)
|
||||
|
||||
# collector
|
||||
train_collector = Collector(
|
||||
policy,
|
||||
train_envs,
|
||||
VectorReplayBuffer(args.buffer_size, len(train_envs)),
|
||||
exploration_noise=False # True
|
||||
)
|
||||
test_collector = Collector(policy, test_envs)
|
||||
# train_collector.collect(n_step=args.batch_size * args.training_num)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, 'pistonball', 'dqn')
|
||||
writer = SummaryWriter(log_path)
|
||||
writer.add_text("args", str(args))
|
||||
logger = TensorboardLogger(writer)
|
||||
|
||||
def save_fn(policy):
|
||||
pass
|
||||
|
||||
def stop_fn(mean_rewards):
|
||||
return False
|
||||
|
||||
def train_fn(epoch, env_step):
|
||||
[agent.set_eps(args.eps_train) for agent in policy.policies.values()]
|
||||
|
||||
def test_fn(epoch, env_step):
|
||||
[agent.set_eps(args.eps_test) for agent in policy.policies.values()]
|
||||
|
||||
def reward_metric(rews):
|
||||
return rews[:, 0]
|
||||
|
||||
# trainer
|
||||
result = onpolicy_trainer(
|
||||
policy,
|
||||
train_collector,
|
||||
test_collector,
|
||||
args.epoch,
|
||||
args.step_per_epoch,
|
||||
args.repeat_per_collect,
|
||||
args.test_num,
|
||||
args.batch_size,
|
||||
episode_per_collect=args.episode_per_collect,
|
||||
stop_fn=stop_fn,
|
||||
save_fn=save_fn,
|
||||
logger=logger,
|
||||
resume_from_log=args.resume
|
||||
)
|
||||
|
||||
return result, policy
|
||||
|
||||
|
||||
def watch(
|
||||
args: argparse.Namespace = get_args(), policy: Optional[BasePolicy] = None
|
||||
) -> None:
|
||||
env = get_env()
|
||||
policy.eval()
|
||||
collector = Collector(policy, env)
|
||||
result = collector.collect(n_episode=1, render=args.render)
|
||||
rews, lens = result["rews"], result["lens"]
|
||||
print(f"Final reward: {rews[:, 0].mean()}, length: {lens.mean()}")
|
||||
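In the continuous variant the `DQN` class above is only a convolutional feature extractor; its flattened output feeds `ActorProb` and `Critic`, and each piston gets its own PPO policy. A minimal sketch of probing the trunk's output size (the observation shape below is illustrative; the real one comes from `observation_space.shape` inside `get_agents`):

# Hedged sketch: inspect the CNN trunk used by the PPO actor and critic.
import torch

from pistonball_continuous import DQN

c, h, w = 3, 120, 120                      # hypothetical channels/height/width
trunk = DQN(c, h, w, device="cpu")
features, _ = trunk(torch.zeros(4, c, h, w))
print(features.shape, trunk.output_dim)    # (4, flattened_dim) and the cached dim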
test/pettingzoo/test_pistonball.py (new file, 21 lines)
@@ -0,0 +1,21 @@
import pprint

from pistonball import get_args, train_agent, watch


def test_piston_ball(args=get_args()):
    if args.watch:
        watch(args)
        return

    result, agent = train_agent(args)
    # assert result["best_reward"] >= args.win_rate

    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        watch(args, agent)


if __name__ == '__main__':
    test_piston_ball(get_args())
test/pettingzoo/test_pistonball_continuous.py (new file, 21 lines)
@@ -0,0 +1,21 @@
import pprint

from pistonball_continuous import get_args, train_agent, watch


def test_piston_ball_continuous(args=get_args()):
    if args.watch:
        watch(args)
        return

    result, agent = train_agent(args)
    assert result["best_reward"] >= 30.0

    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        watch(args, agent)


if __name__ == '__main__':
    test_piston_ball_continuous(get_args())
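Both smoke tests follow the same pattern as the tic-tac-toe test: parse args, train briefly, and optionally watch. A hedged sketch of running them through pytest's Python API (paths assume the repository root):

# Hedged sketch: run the new PettingZoo smoke tests programmatically.
import pytest

raise SystemExit(
    pytest.main(
        [
            "test/pettingzoo/test_pistonball.py",
            "test/pettingzoo/test_pistonball_continuous.py",
            "-q",
        ]
    )
)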
@@ -1,21 +1,21 @@
import pprint

from tic_tac_toe import get_args, train_agent, watch


def test_tic_tac_toe(args=get_args()):
    if args.watch:
        watch(args)
        return

    result, agent = train_agent(args)
    assert result["best_reward"] >= args.win_rate

    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        watch(args, agent)


if __name__ == '__main__':
    test_tic_tac_toe(get_args())
@@ -1,232 +1,241 @@
|
||||
import argparse
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tic_tac_toe_env import TicTacToeEnv
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.data import Collector, VectorReplayBuffer
|
||||
from tianshou.env import DummyVectorEnv
|
||||
from tianshou.policy import (
|
||||
BasePolicy,
|
||||
DQNPolicy,
|
||||
MultiAgentPolicyManager,
|
||||
RandomPolicy,
|
||||
)
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.utils import TensorboardLogger
|
||||
from tianshou.utils.net.common import Net
|
||||
|
||||
|
||||
def get_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--seed', type=int, default=1626)
|
||||
parser.add_argument('--eps-test', type=float, default=0.05)
|
||||
parser.add_argument('--eps-train', type=float, default=0.1)
|
||||
parser.add_argument('--buffer-size', type=int, default=20000)
|
||||
parser.add_argument('--lr', type=float, default=1e-3)
|
||||
parser.add_argument(
|
||||
'--gamma', type=float, default=0.9, help='a smaller gamma favors earlier win'
|
||||
)
|
||||
parser.add_argument('--n-step', type=int, default=3)
|
||||
parser.add_argument('--target-update-freq', type=int, default=320)
|
||||
parser.add_argument('--epoch', type=int, default=20)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=5000)
|
||||
parser.add_argument('--step-per-collect', type=int, default=10)
|
||||
parser.add_argument('--update-per-step', type=float, default=0.1)
|
||||
parser.add_argument('--batch-size', type=int, default=64)
|
||||
parser.add_argument(
|
||||
'--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]
|
||||
)
|
||||
parser.add_argument('--training-num', type=int, default=10)
|
||||
parser.add_argument('--test-num', type=int, default=100)
|
||||
parser.add_argument('--logdir', type=str, default='log')
|
||||
parser.add_argument('--render', type=float, default=0.1)
|
||||
parser.add_argument('--board-size', type=int, default=6)
|
||||
parser.add_argument('--win-size', type=int, default=4)
|
||||
parser.add_argument(
|
||||
'--win-rate', type=float, default=0.9, help='the expected winning rate'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--watch',
|
||||
default=False,
|
||||
action='store_true',
|
||||
help='no training, '
|
||||
'watch the play of pre-trained models'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--agent-id',
|
||||
type=int,
|
||||
default=2,
|
||||
help='the learned agent plays as the'
|
||||
' agent_id-th player. Choices are 1 and 2.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--resume-path',
|
||||
type=str,
|
||||
default='',
|
||||
help='the path of agent pth file '
|
||||
'for resuming from a pre-trained agent'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--opponent-path',
|
||||
type=str,
|
||||
default='',
|
||||
help='the path of opponent agent pth file '
|
||||
'for resuming from a pre-trained agent'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def get_args() -> argparse.Namespace:
|
||||
parser = get_parser()
|
||||
return parser.parse_known_args()[0]
|
||||
|
||||
|
||||
def get_agents(
|
||||
args: argparse.Namespace = get_args(),
|
||||
agent_learn: Optional[BasePolicy] = None,
|
||||
agent_opponent: Optional[BasePolicy] = None,
|
||||
optim: Optional[torch.optim.Optimizer] = None,
|
||||
) -> Tuple[BasePolicy, torch.optim.Optimizer]:
|
||||
env = TicTacToeEnv(args.board_size, args.win_size)
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
if agent_learn is None:
|
||||
# model
|
||||
net = Net(
|
||||
args.state_shape,
|
||||
args.action_shape,
|
||||
hidden_sizes=args.hidden_sizes,
|
||||
device=args.device
|
||||
).to(args.device)
|
||||
if optim is None:
|
||||
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
|
||||
agent_learn = DQNPolicy(
|
||||
net,
|
||||
optim,
|
||||
args.gamma,
|
||||
args.n_step,
|
||||
target_update_freq=args.target_update_freq
|
||||
)
|
||||
if args.resume_path:
|
||||
agent_learn.load_state_dict(torch.load(args.resume_path))
|
||||
|
||||
if agent_opponent is None:
|
||||
if args.opponent_path:
|
||||
agent_opponent = deepcopy(agent_learn)
|
||||
agent_opponent.load_state_dict(torch.load(args.opponent_path))
|
||||
else:
|
||||
agent_opponent = RandomPolicy()
|
||||
|
||||
if args.agent_id == 1:
|
||||
agents = [agent_learn, agent_opponent]
|
||||
else:
|
||||
agents = [agent_opponent, agent_learn]
|
||||
policy = MultiAgentPolicyManager(agents)
|
||||
return policy, optim
|
||||
|
||||
|
||||
def train_agent(
|
||||
args: argparse.Namespace = get_args(),
|
||||
agent_learn: Optional[BasePolicy] = None,
|
||||
agent_opponent: Optional[BasePolicy] = None,
|
||||
optim: Optional[torch.optim.Optimizer] = None,
|
||||
) -> Tuple[dict, BasePolicy]:
|
||||
|
||||
def env_func():
|
||||
return TicTacToeEnv(args.board_size, args.win_size)
|
||||
|
||||
train_envs = DummyVectorEnv([env_func for _ in range(args.training_num)])
|
||||
test_envs = DummyVectorEnv([env_func for _ in range(args.test_num)])
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
train_envs.seed(args.seed)
|
||||
test_envs.seed(args.seed)
|
||||
|
||||
policy, optim = get_agents(
|
||||
args, agent_learn=agent_learn, agent_opponent=agent_opponent, optim=optim
|
||||
)
|
||||
|
||||
# collector
|
||||
train_collector = Collector(
|
||||
policy,
|
||||
train_envs,
|
||||
VectorReplayBuffer(args.buffer_size, len(train_envs)),
|
||||
exploration_noise=True
|
||||
)
|
||||
test_collector = Collector(policy, test_envs, exploration_noise=True)
|
||||
# policy.set_eps(1)
|
||||
train_collector.collect(n_step=args.batch_size * args.training_num)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn')
|
||||
writer = SummaryWriter(log_path)
|
||||
writer.add_text("args", str(args))
|
||||
logger = TensorboardLogger(writer)
|
||||
|
||||
def save_fn(policy):
|
||||
if hasattr(args, 'model_save_path'):
|
||||
model_save_path = args.model_save_path
|
||||
else:
|
||||
model_save_path = os.path.join(
|
||||
args.logdir, 'tic_tac_toe', 'dqn', 'policy.pth'
|
||||
)
|
||||
torch.save(policy.policies[args.agent_id - 1].state_dict(), model_save_path)
|
||||
|
||||
def stop_fn(mean_rewards):
|
||||
return mean_rewards >= args.win_rate
|
||||
|
||||
def train_fn(epoch, env_step):
|
||||
policy.policies[args.agent_id - 1].set_eps(args.eps_train)
|
||||
|
||||
def test_fn(epoch, env_step):
|
||||
policy.policies[args.agent_id - 1].set_eps(args.eps_test)
|
||||
|
||||
def reward_metric(rews):
|
||||
return rews[:, args.agent_id - 1]
|
||||
|
||||
# trainer
|
||||
result = offpolicy_trainer(
|
||||
policy,
|
||||
train_collector,
|
||||
test_collector,
|
||||
args.epoch,
|
||||
args.step_per_epoch,
|
||||
args.step_per_collect,
|
||||
args.test_num,
|
||||
args.batch_size,
|
||||
train_fn=train_fn,
|
||||
test_fn=test_fn,
|
||||
stop_fn=stop_fn,
|
||||
save_fn=save_fn,
|
||||
update_per_step=args.update_per_step,
|
||||
logger=logger,
|
||||
test_in_train=False,
|
||||
reward_metric=reward_metric
|
||||
)
|
||||
|
||||
return result, policy.policies[args.agent_id - 1]
|
||||
|
||||
|
||||
def watch(
|
||||
args: argparse.Namespace = get_args(),
|
||||
agent_learn: Optional[BasePolicy] = None,
|
||||
agent_opponent: Optional[BasePolicy] = None,
|
||||
) -> None:
|
||||
env = TicTacToeEnv(args.board_size, args.win_size)
|
||||
policy, optim = get_agents(
|
||||
args, agent_learn=agent_learn, agent_opponent=agent_opponent
|
||||
)
|
||||
policy.eval()
|
||||
policy.policies[args.agent_id - 1].set_eps(args.eps_test)
|
||||
collector = Collector(policy, env, exploration_noise=True)
|
||||
result = collector.collect(n_episode=1, render=args.render)
|
||||
rews, lens = result["rews"], result["lens"]
|
||||
print(f"Final reward: {rews[:, args.agent_id - 1].mean()}, length: {lens.mean()}")
|
||||
import argparse
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from pettingzoo.classic import tictactoe_v3
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.data import Collector, VectorReplayBuffer
|
||||
from tianshou.env import DummyVectorEnv
|
||||
from tianshou.env.pettingzoo_env import PettingZooEnv
|
||||
from tianshou.policy import (
|
||||
BasePolicy,
|
||||
DQNPolicy,
|
||||
MultiAgentPolicyManager,
|
||||
RandomPolicy,
|
||||
)
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.utils import TensorboardLogger
|
||||
from tianshou.utils.net.common import Net
|
||||
|
||||
|
||||
def get_env():
|
||||
return PettingZooEnv(tictactoe_v3.env())
|
||||
|
||||
|
||||
def get_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--seed', type=int, default=1626)
|
||||
parser.add_argument('--eps-test', type=float, default=0.05)
|
||||
parser.add_argument('--eps-train', type=float, default=0.1)
|
||||
parser.add_argument('--buffer-size', type=int, default=20000)
|
||||
parser.add_argument('--lr', type=float, default=1e-4)
|
||||
parser.add_argument(
|
||||
'--gamma', type=float, default=0.9, help='a smaller gamma favors earlier win'
|
||||
)
|
||||
parser.add_argument('--n-step', type=int, default=3)
|
||||
parser.add_argument('--target-update-freq', type=int, default=320)
|
||||
parser.add_argument('--epoch', type=int, default=50)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=1000)
|
||||
parser.add_argument('--step-per-collect', type=int, default=10)
|
||||
parser.add_argument('--update-per-step', type=float, default=0.1)
|
||||
parser.add_argument('--batch-size', type=int, default=64)
|
||||
parser.add_argument(
|
||||
'--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]
|
||||
)
|
||||
parser.add_argument('--training-num', type=int, default=10)
|
||||
parser.add_argument('--test-num', type=int, default=100)
|
||||
parser.add_argument('--logdir', type=str, default='log')
|
||||
parser.add_argument('--render', type=float, default=0.1)
|
||||
parser.add_argument(
|
||||
'--win-rate',
|
||||
type=float,
|
||||
default=0.6,
|
||||
help='the expected winning rate: Optimal policy can get 0.7'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--watch',
|
||||
default=False,
|
||||
action='store_true',
|
||||
help='no training, '
|
||||
'watch the play of pre-trained models'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--agent-id',
|
||||
type=int,
|
||||
default=2,
|
||||
help='the learned agent plays as the'
|
||||
' agent_id-th player. Choices are 1 and 2.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--resume-path',
|
||||
type=str,
|
||||
default='',
|
||||
help='the path of agent pth file '
|
||||
'for resuming from a pre-trained agent'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--opponent-path',
|
||||
type=str,
|
||||
default='',
|
||||
help='the path of opponent agent pth file '
|
||||
'for resuming from a pre-trained agent'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def get_args() -> argparse.Namespace:
|
||||
parser = get_parser()
|
||||
return parser.parse_known_args()[0]
|
||||
|
||||
|
||||
def get_agents(
|
||||
args: argparse.Namespace = get_args(),
|
||||
agent_learn: Optional[BasePolicy] = None,
|
||||
agent_opponent: Optional[BasePolicy] = None,
|
||||
optim: Optional[torch.optim.Optimizer] = None,
|
||||
) -> Tuple[BasePolicy, torch.optim.Optimizer, list]:
|
||||
env = get_env()
|
||||
observation_space = env.observation_space['observation'] if isinstance(
|
||||
env.observation_space, gym.spaces.Dict
|
||||
) else env.observation_space
|
||||
args.state_shape = observation_space.shape or observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
if agent_learn is None:
|
||||
# model
|
||||
net = Net(
|
||||
args.state_shape,
|
||||
args.action_shape,
|
||||
hidden_sizes=args.hidden_sizes,
|
||||
device=args.device
|
||||
).to(args.device)
|
||||
if optim is None:
|
||||
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
|
||||
agent_learn = DQNPolicy(
|
||||
net,
|
||||
optim,
|
||||
args.gamma,
|
||||
args.n_step,
|
||||
target_update_freq=args.target_update_freq
|
||||
)
|
||||
if args.resume_path:
|
||||
agent_learn.load_state_dict(torch.load(args.resume_path))
|
||||
|
||||
if agent_opponent is None:
|
||||
if args.opponent_path:
|
||||
agent_opponent = deepcopy(agent_learn)
|
||||
agent_opponent.load_state_dict(torch.load(args.opponent_path))
|
||||
else:
|
||||
agent_opponent = RandomPolicy()
|
||||
|
||||
if args.agent_id == 1:
|
||||
agents = [agent_learn, agent_opponent]
|
||||
else:
|
||||
agents = [agent_opponent, agent_learn]
|
||||
policy = MultiAgentPolicyManager(agents, env)
|
||||
return policy, optim, env.agents
|
||||
|
||||
|
||||
def train_agent(
|
||||
args: argparse.Namespace = get_args(),
|
||||
agent_learn: Optional[BasePolicy] = None,
|
||||
agent_opponent: Optional[BasePolicy] = None,
|
||||
optim: Optional[torch.optim.Optimizer] = None,
|
||||
) -> Tuple[dict, BasePolicy]:
|
||||
|
||||
train_envs = DummyVectorEnv([get_env for _ in range(args.training_num)])
|
||||
test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)])
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
train_envs.seed(args.seed)
|
||||
test_envs.seed(args.seed)
|
||||
|
||||
policy, optim, agents = get_agents(
|
||||
args, agent_learn=agent_learn, agent_opponent=agent_opponent, optim=optim
|
||||
)
|
||||
|
||||
# collector
|
||||
train_collector = Collector(
|
||||
policy,
|
||||
train_envs,
|
||||
VectorReplayBuffer(args.buffer_size, len(train_envs)),
|
||||
exploration_noise=True
|
||||
)
|
||||
test_collector = Collector(policy, test_envs, exploration_noise=True)
|
||||
# policy.set_eps(1)
|
||||
train_collector.collect(n_step=args.batch_size * args.training_num)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn')
|
||||
writer = SummaryWriter(log_path)
|
||||
writer.add_text("args", str(args))
|
||||
logger = TensorboardLogger(writer)
|
||||
|
||||
def save_fn(policy):
|
||||
if hasattr(args, 'model_save_path'):
|
||||
model_save_path = args.model_save_path
|
||||
else:
|
||||
model_save_path = os.path.join(
|
||||
args.logdir, 'tic_tac_toe', 'dqn', 'policy.pth'
|
||||
)
|
||||
torch.save(
|
||||
policy.policies[agents[args.agent_id - 1]].state_dict(), model_save_path
|
||||
)
|
||||
|
||||
def stop_fn(mean_rewards):
|
||||
return mean_rewards >= args.win_rate
|
||||
|
||||
def train_fn(epoch, env_step):
|
||||
policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_train)
|
||||
|
||||
def test_fn(epoch, env_step):
|
||||
policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test)
|
||||
|
||||
def reward_metric(rews):
|
||||
return rews[:, args.agent_id - 1]
|
||||
|
||||
# trainer
|
||||
result = offpolicy_trainer(
|
||||
policy,
|
||||
train_collector,
|
||||
test_collector,
|
||||
args.epoch,
|
||||
args.step_per_epoch,
|
||||
args.step_per_collect,
|
||||
args.test_num,
|
||||
args.batch_size,
|
||||
train_fn=train_fn,
|
||||
test_fn=test_fn,
|
||||
stop_fn=stop_fn,
|
||||
save_fn=save_fn,
|
||||
update_per_step=args.update_per_step,
|
||||
logger=logger,
|
||||
test_in_train=False,
|
||||
reward_metric=reward_metric
|
||||
)
|
||||
|
||||
return result, policy.policies[agents[args.agent_id - 1]]
|
||||
|
||||
|
||||
def watch(
|
||||
args: argparse.Namespace = get_args(),
|
||||
agent_learn: Optional[BasePolicy] = None,
|
||||
agent_opponent: Optional[BasePolicy] = None,
|
||||
) -> None:
|
||||
env = get_env()
|
||||
policy, optim, agents = get_agents(
|
||||
args, agent_learn=agent_learn, agent_opponent=agent_opponent
|
||||
)
|
||||
policy.eval()
|
||||
policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test)
|
||||
collector = Collector(policy, env, exploration_noise=True)
|
||||
result = collector.collect(n_episode=1, render=args.render)
|
||||
rews, lens = result["rews"], result["lens"]
|
||||
print(f"Final reward: {rews[:, args.agent_id - 1].mean()}, length: {lens.mean()}")
|
||||
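The rewritten tic_tac_toe.py drops the hand-rolled TicTacToeEnv in favour of PettingZoo's tictactoe_v3 wrapped in PettingZooEnv, and MultiAgentPolicyManager now keys its sub-policies by the environment's agent names. A minimal sketch of stepping the wrapped board by hand, assuming pettingzoo is installed (the random move is illustrative):

# Hedged sketch: one manual move in the wrapped tic-tac-toe environment.
import numpy as np
from pettingzoo.classic import tictactoe_v3

from tianshou.env.pettingzoo_env import PettingZooEnv

env = PettingZooEnv(tictactoe_v3.env())
obs = env.reset()                        # dict with 'agent_id', 'obs' and 'mask'
legal = np.nonzero(obs['mask'])[0]       # indices whose mask entry is True
obs, rew, done, info = env.step(int(np.random.choice(legal)))
print(obs['agent_id'], rew, done)        # next player, vector reward, done flag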
tianshou/env/__init__.py (4 changes)
@@ -1,6 +1,6 @@
"""Env package."""

from tianshou.env.maenv import MultiAgentEnv
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.env.venvs import (
    BaseVectorEnv,
    DummyVectorEnv,
@@ -15,5 +15,5 @@ __all__ = [
    "SubprocVectorEnv",
    "ShmemVectorEnv",
    "RayVectorEnv",
    "MultiAgentEnv",
    "PettingZooEnv",
]
tianshou/env/maenv.py (deleted, 65 lines)
@@ -1,65 +0,0 @@
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MultiAgentEnv(ABC, gym.Env):
|
||||
"""The interface for multi-agent environments.
|
||||
|
||||
Multi-agent environments must be wrapped as
|
||||
:class:`~tianshou.env.MultiAgentEnv`. Here is the usage:
|
||||
::
|
||||
|
||||
env = MultiAgentEnv(...)
|
||||
# obs is a dict containing obs, agent_id, and mask
|
||||
obs = env.reset()
|
||||
act = policy(obs)
|
||||
obs, rew, done, info = env.step(act)
|
||||
env.close()
|
||||
|
||||
The available action's mask is set to 1, otherwise it is set to 0. Further
|
||||
usage can be found at :ref:`marl_example`.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reset(self) -> dict:
|
||||
"""Reset the state.
|
||||
|
||||
Return the initial state, first agent_id, and the initial action set,
|
||||
for example, ``{'obs': obs, 'agent_id': agent_id, 'mask': mask}``.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def step(
|
||||
self, action: np.ndarray
|
||||
) -> Tuple[Dict[str, Any], np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""Run one timestep of the environment’s dynamics.
|
||||
|
||||
When the end of episode is reached, you are responsible for calling
|
||||
reset() to reset the environment’s state.
|
||||
|
||||
Accept action and return a tuple (obs, rew, done, info).
|
||||
|
||||
:param numpy.ndarray action: action provided by a agent.
|
||||
|
||||
:return: A tuple including four items:
|
||||
|
||||
* ``obs`` a dict containing obs, agent_id, and mask, which means \
|
||||
that it is the ``agent_id`` player's turn to play with ``obs``\
|
||||
observation and ``mask``.
|
||||
* ``rew`` a numpy.ndarray, the amount of rewards returned after \
|
||||
previous actions. Depending on the specific environment, this \
|
||||
can be either a scalar reward for current agent or a vector \
|
||||
reward for all the agents.
|
||||
* ``done`` a numpy.ndarray, whether the episode has ended, in \
|
||||
which case further step() calls will return undefined results
|
||||
* ``info`` a numpy.ndarray, contains auxiliary diagnostic \
|
||||
information (helpful for debugging, and sometimes learning)
|
||||
"""
|
||||
pass
|
||||
tianshou/env/pettingzoo_env.py (new file, 112 lines)
@@ -0,0 +1,112 @@
from abc import ABC
from typing import Any, Dict, List, Tuple

import gym.spaces
from pettingzoo.utils.env import AECEnv
from pettingzoo.utils.wrappers import BaseWrapper


class PettingZooEnv(AECEnv, gym.Env, ABC):
    """The interface for petting zoo environments.

    Multi-agent environments must be wrapped as
    :class:`~tianshou.env.PettingZooEnv`. Here is the usage:
    ::

        env = PettingZooEnv(...)
        # obs is a dict containing obs, agent_id, and mask
        obs = env.reset()
        action = policy(obs)
        obs, rew, done, info = env.step(action)
        env.close()

    The available action's mask is set to True, otherwise it is set to False.
    Further usage can be found at :ref:`marl_example`.
    """

    def __init__(self, env: BaseWrapper):
        super().__init__()
        self.env = env
        # agent idx list
        self.agents = self.env.possible_agents
        self.agent_idx = {}
        for i, agent_id in enumerate(self.agents):
            self.agent_idx[agent_id] = i
        # Get dictionaries of obs_spaces and act_spaces
        self.observation_spaces = self.env.observation_spaces
        self.action_spaces = self.env.action_spaces

        self.rewards = [0] * len(self.agents)

        # Get first observation space, assuming all agents have equal space
        self.observation_space: Any = self.observation_space(self.agents[0])

        # Get first action space, assuming all agents have equal space
        self.action_space: Any = self.action_space(self.agents[0])

        assert all(self.env.observation_space(agent) == self.observation_space
                   for agent in self.agents), \
            "Observation spaces for all agents must be identical. Perhaps " \
            "SuperSuit's pad_observations wrapper can help (usage: " \
            "`supersuit.aec_wrappers.pad_observations(env)`"

        assert all(self.env.action_space(agent) == self.action_space
                   for agent in self.agents), \
            "Action spaces for all agents must be identical. Perhaps " \
            "SuperSuit's pad_action_space wrapper can help (usage: " \
            "`supersuit.aec_wrappers.pad_action_space(env)`"

        self.reset()

    def reset(self) -> dict:
        self.env.reset()
        observation = self.env.observe(self.env.agent_selection)
        if isinstance(observation, dict) and 'action_mask' in observation:
            return {
                'agent_id': self.env.agent_selection,
                'obs': observation['observation'],
                'mask':
                [True if obm == 1 else False for obm in observation['action_mask']]
            }
        else:
            if isinstance(self.action_space, gym.spaces.Discrete):
                return {
                    'agent_id': self.env.agent_selection,
                    'obs': observation,
                    'mask': [True] * self.env.action_space(self.env.agent_selection).n
                }
            else:
                return {'agent_id': self.env.agent_selection, 'obs': observation}

    def step(self, action: Any) -> Tuple[Dict, List[int], bool, Dict]:
        self.env.step(action)
        observation, rew, done, info = self.env.last()
        if isinstance(observation, dict) and 'action_mask' in observation:
            obs = {
                'agent_id': self.env.agent_selection,
                'obs': observation['observation'],
                'mask':
                [True if obm == 1 else False for obm in observation['action_mask']]
            }
        else:
            if isinstance(self.action_space, gym.spaces.Discrete):
                obs = {
                    'agent_id': self.env.agent_selection,
                    'obs': observation,
                    'mask': [True] * self.env.action_space(self.env.agent_selection).n
                }
            else:
                obs = {'agent_id': self.env.agent_selection, 'obs': observation}

        for agent_id, reward in self.env.rewards.items():
            self.rewards[self.agent_idx[agent_id]] = reward
        return obs, self.rewards, done, info

    def close(self) -> None:
        self.env.close()

    def seed(self, seed: Any = None) -> None:
        self.env.seed(seed)

    def render(self, mode: str = "human") -> Any:
        return self.env.render(mode)
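Note that `step()` hands back the full reward vector ordered by `agent_idx`, not a per-agent scalar, so downstream code selects its own column, which is exactly what the `reward_metric` helpers in the examples above do. A hedged sketch of that selection in isolation (the names are illustrative):

# Hedged sketch: pick one agent's column out of the vector reward.
import numpy as np


def reward_metric_for(agent_index: int):
    def reward_metric(rews: np.ndarray) -> np.ndarray:
        # rews has shape (batch, n_agents); keep only the chosen agent's column
        return rews[:, agent_index]

    return reward_metric


metric = reward_metric_for(0)
print(metric(np.array([[1.0, -1.0], [0.0, 0.0]])))  # -> [1. 0.]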
tianshou/env/venvs.py (6 changes)
@@ -2,6 +2,7 @@ from typing import Any, Callable, List, Optional, Tuple, Union

import gym
import numpy as np
import pettingzoo

from tianshou.env.worker import (
    DummyEnvWorker,
@@ -364,7 +365,10 @@ class DummyVectorEnv(BaseVectorEnv):
    Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
    def __init__(
        self, env_fns: List[Callable[[], Union[gym.Env, pettingzoo.AECEnv]]],
        **kwargs: Any
    ) -> None:
        super().__init__(env_fns, DummyEnvWorker, **kwargs)
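With the widened type hint, DummyVectorEnv now explicitly accepts factories returning wrapped PettingZoo AEC environments alongside plain Gym envs. A minimal sketch, assuming pettingzoo is installed:

# Hedged sketch: vectorize a PettingZoo environment the same way as a Gym one.
from pettingzoo.classic import tictactoe_v3

from tianshou.env import DummyVectorEnv
from tianshou.env.pettingzoo_env import PettingZooEnv


def make_env():
    return PettingZooEnv(tictactoe_v3.env())


venv = DummyVectorEnv([make_env for _ in range(4)])
venv.seed(42)
print(venv.reset())  # a batch of four observation dicts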
@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
import numpy as np
|
||||
|
||||
from tianshou.data import Batch, ReplayBuffer
|
||||
from tianshou.env.pettingzoo_env import PettingZooEnv
|
||||
from tianshou.policy import BasePolicy
|
||||
|
||||
|
||||
@@ -16,21 +17,29 @@ class MultiAgentPolicyManager(BasePolicy):
|
||||
:ref:`marl_example` can help you better understand this procedure.
|
||||
"""
|
||||
|
||||
def __init__(self, policies: List[BasePolicy], **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.policies = policies
|
||||
def __init__(
|
||||
self, policies: List[BasePolicy], env: PettingZooEnv, **kwargs: Any
|
||||
) -> None:
|
||||
super().__init__(action_space=env.action_space, **kwargs)
|
||||
assert (
|
||||
len(policies) == len(env.agents)
|
||||
), "One policy must be assigned for each agent."
|
||||
|
||||
self.agent_idx = env.agent_idx
|
||||
for i, policy in enumerate(policies):
|
||||
# agent_id 0 is reserved for the environment proxy
|
||||
# (this MultiAgentPolicyManager)
|
||||
policy.set_agent_id(i + 1)
|
||||
policy.set_agent_id(env.agents[i])
|
||||
|
||||
self.policies = dict(zip(env.agents, policies))
|
||||
|
||||
def replace_policy(self, policy: BasePolicy, agent_id: int) -> None:
|
||||
"""Replace the "agent_id"th policy in this manager."""
|
||||
self.policies[agent_id - 1] = policy
|
||||
policy.set_agent_id(agent_id)
|
||||
self.policies[agent_id] = policy
|
||||
|
||||
def process_fn(
|
||||
self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
|
||||
self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
|
||||
) -> Batch:
|
||||
"""Dispatch batch data from obs.agent_id to every policy's process_fn.
|
||||
|
||||
@@ -45,18 +54,21 @@ class MultiAgentPolicyManager(BasePolicy):
|
||||
# Since we do not override buffer.__setattr__, here we use _meta to
|
||||
# change buffer.rew, otherwise buffer.rew = Batch() has no effect.
|
||||
save_rew, buffer._meta.rew = buffer.rew, Batch()
|
||||
for policy in self.policies:
|
||||
agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0]
|
||||
for agent, policy in self.policies.items():
|
||||
agent_index = np.nonzero(batch.obs.agent_id == agent)[0]
|
||||
if len(agent_index) == 0:
|
||||
results[f"agent_{policy.agent_id}"] = Batch()
|
||||
results[agent] = Batch()
|
||||
continue
|
||||
tmp_batch, tmp_indices = batch[agent_index], indices[agent_index]
|
||||
tmp_batch, tmp_indice = batch[agent_index], indice[agent_index]
|
||||
if has_rew:
|
||||
tmp_batch.rew = tmp_batch.rew[:, policy.agent_id - 1]
|
||||
buffer._meta.rew = save_rew[:, policy.agent_id - 1]
|
||||
results[f"agent_{policy.agent_id}"] = policy.process_fn(
|
||||
tmp_batch, buffer, tmp_indices
|
||||
)
|
||||
tmp_batch.rew = tmp_batch.rew[:, self.agent_idx[agent]]
|
||||
buffer._meta.rew = save_rew[:, self.agent_idx[agent]]
|
||||
if not hasattr(tmp_batch.obs, "mask"):
|
||||
if hasattr(tmp_batch.obs, 'obs'):
|
||||
tmp_batch.obs = tmp_batch.obs.obs
|
||||
if hasattr(tmp_batch.obs_next, 'obs'):
|
||||
tmp_batch.obs_next = tmp_batch.obs_next.obs
|
||||
results[agent] = policy.process_fn(tmp_batch, buffer, tmp_indice)
|
||||
if has_rew: # restore from save_rew
|
||||
buffer._meta.rew = save_rew
|
||||
return Batch(results)
|
||||
@@ -64,8 +76,8 @@ class MultiAgentPolicyManager(BasePolicy):
|
||||
def exploration_noise(self, act: Union[np.ndarray, Batch],
|
||||
batch: Batch) -> Union[np.ndarray, Batch]:
|
||||
"""Add exploration noise from sub-policy onto act."""
|
||||
for policy in self.policies:
|
||||
agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0]
|
||||
for agent_id, policy in self.policies.items():
|
||||
agent_index = np.nonzero(batch.obs.agent_id == agent_id)[0]
|
||||
if len(agent_index) == 0:
|
||||
continue
|
||||
act[agent_index] = policy.exploration_noise(
|
||||
@@ -104,7 +116,7 @@ class MultiAgentPolicyManager(BasePolicy):
|
||||
"""
|
||||
results: List[Tuple[bool, np.ndarray, Batch, Union[np.ndarray, Batch],
|
||||
Batch]] = []
|
||||
for policy in self.policies:
|
||||
for agent_id, policy in self.policies.items():
|
||||
# This part of code is difficult to understand.
|
||||
# Let's follow an example with two agents
|
||||
# batch.obs.agent_id is [1, 2, 1, 2, 1, 2] (with batch_size == 6)
|
||||
@@ -112,7 +124,7 @@ class MultiAgentPolicyManager(BasePolicy):
|
||||
# agent_index for agent 1 is [0, 2, 4]
|
||||
# agent_index for agent 2 is [1, 3, 5]
|
||||
# we separate the transition of each agent according to agent_id
|
||||
agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0]
|
||||
agent_index = np.nonzero(batch.obs.agent_id == agent_id)[0]
|
||||
if len(agent_index) == 0:
|
||||
# (has_data, agent_index, out, act, state)
|
||||
results.append((False, np.array([-1]), Batch(), Batch(), Batch()))
|
||||
@@ -120,11 +132,15 @@ class MultiAgentPolicyManager(BasePolicy):
|
||||
tmp_batch = batch[agent_index]
|
||||
if isinstance(tmp_batch.rew, np.ndarray):
|
||||
# reward can be empty Batch (after initial reset) or nparray.
|
||||
tmp_batch.rew = tmp_batch.rew[:, policy.agent_id - 1]
|
||||
tmp_batch.rew = tmp_batch.rew[:, self.agent_idx[agent_id]]
|
||||
if not hasattr(tmp_batch.obs, "mask"):
|
||||
if hasattr(tmp_batch.obs, 'obs'):
|
||||
tmp_batch.obs = tmp_batch.obs.obs
|
||||
if hasattr(tmp_batch.obs_next, 'obs'):
|
||||
tmp_batch.obs_next = tmp_batch.obs_next.obs
|
||||
out = policy(
|
||||
batch=tmp_batch,
|
||||
state=None if state is None else state["agent_" +
|
||||
str(policy.agent_id)],
|
||||
state=None if state is None else state[agent_id],
|
||||
**kwargs
|
||||
)
|
||||
act = out.act
|
||||
@@ -141,12 +157,12 @@ class MultiAgentPolicyManager(BasePolicy):
|
||||
]
|
||||
)
|
||||
state_dict, out_dict = {}, {}
|
||||
for policy, (has_data, agent_index, out, act,
|
||||
state) in zip(self.policies, results):
|
||||
for (agent_id, _), (has_data, agent_index, out, act,
|
||||
state) in zip(self.policies.items(), results):
|
||||
if has_data:
|
||||
holder.act[agent_index] = act
|
||||
state_dict["agent_" + str(policy.agent_id)] = state
|
||||
out_dict["agent_" + str(policy.agent_id)] = out
|
||||
state_dict[agent_id] = state
|
||||
out_dict[agent_id] = out
|
||||
holder["out"] = out_dict
|
||||
holder["state"] = state_dict
|
||||
return holder
|
||||
@@ -168,10 +184,10 @@ class MultiAgentPolicyManager(BasePolicy):
|
||||
}
|
||||
"""
|
||||
results = {}
|
||||
for policy in self.policies:
|
||||
data = batch[f"agent_{policy.agent_id}"]
|
||||
for agent_id, policy in self.policies.items():
|
||||
data = batch[agent_id]
|
||||
if not data.is_empty():
|
||||
out = policy.learn(batch=data, **kwargs)
|
||||
for k, v in out.items():
|
||||
results["agent_" + str(policy.agent_id) + "/" + k] = v
|
||||
results[agent_id + "/" + k] = v
|
||||
return results
|
||||