Pettingzoo support (#494)

Co-authored-by: Rodrigo de Lazcano <r.l.p.v96@gmail.com>
Co-authored-by: J K Terry <justinkterry@gmail.com>
Mohammad Mahdi Rahimi 2022-02-15 17:56:45 +03:00 committed by GitHub
parent d85bc19269
commit c7e2e56fac
14 changed files with 933 additions and 585 deletions


@@ -55,6 +55,7 @@ setup(
         "torch>=1.4.0",
         "numba>=0.51.0",
         "h5py>=2.10.0",  # to match tensorflow's minimal requirements
+        "pettingzoo>=1.12,<=1.13",
     ],
     extras_require={
         "dev": [
@@ -74,6 +75,9 @@ setup(
             "pydocstyle",
             "doc8",
             "scipy",
+            "pillow",
+            "pygame>=2.1.0",  # pettingzoo test cases pistonball
+            "pymunk>=6.2.1",  # pettingzoo test cases pistonball
         ],
         "atari": ["atari_py", "opencv-python"],
         "mujoco": ["mujoco_py"],


@@ -1,87 +0,0 @@
import os
import pprint
from copy import deepcopy
import numpy as np
from tic_tac_toe import get_agents, get_parser, train_agent, watch
from tic_tac_toe_env import TicTacToeEnv
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector
from tianshou.env import DummyVectorEnv
from tianshou.policy import RandomPolicy
from tianshou.utils import TensorboardLogger
def get_args():
parser = get_parser()
parser.add_argument('--self_play_round', type=int, default=20)
args = parser.parse_known_args()[0]
return args
def gomoku(args=get_args()):
Collector._default_rew_metric = lambda x: x[args.agent_id - 1]
if args.watch:
watch(args)
return
policy, optim = get_agents(args)
agent_learn = policy.policies[args.agent_id - 1]
agent_opponent = policy.policies[2 - args.agent_id]
# log
log_path = os.path.join(args.logdir, 'Gomoku', 'dqn')
writer = SummaryWriter(log_path)
args.logger = TensorboardLogger(writer)
opponent_pool = [agent_opponent]
def env_func():
return TicTacToeEnv(args.board_size, args.win_size)
test_envs = DummyVectorEnv([env_func for _ in range(args.test_num)])
for round in range(args.self_play_round):
rews = []
agent_learn.set_eps(0.0)
# compute the reward over previous learner
for opponent in opponent_pool:
policy.replace_policy(opponent, 3 - args.agent_id)
test_collector = Collector(policy, test_envs)
results = test_collector.collect(n_episode=100)
rews.append(results['rews'].mean())
rews = np.array(rews)
# weight opponent by their difficulty level
rews = np.exp(-rews * 10.0)
rews /= np.sum(rews)
total_epoch = args.epoch
args.epoch = 1
for epoch in range(total_epoch):
# sample one opponent
opp_id = np.random.choice(len(opponent_pool), size=1, p=rews)
print(f'selection probability {rews.tolist()}')
print(f'selected opponent {opp_id}')
opponent = opponent_pool[opp_id.item(0)]
agent = RandomPolicy()
# previous learner can only be used for forward
agent.forward = opponent.forward
args.model_save_path = os.path.join(
args.logdir, 'Gomoku', 'dqn', f'policy_round_{round}_epoch_{epoch}.pth'
)
result, agent_learn = train_agent(
args, agent_learn=agent_learn, agent_opponent=agent, optim=optim
)
print(f'round_{round}_epoch_{epoch}')
pprint.pprint(result)
learnt_agent = deepcopy(agent_learn)
learnt_agent.set_eps(0.0)
opponent_pool.append(learnt_agent)
args.epoch = total_epoch
if __name__ == '__main__':
# Let's watch its performance!
opponent = opponent_pool[-2]
watch(args, agent_learn, opponent)
if __name__ == '__main__':
gomoku(get_args())


@@ -1,148 +0,0 @@
from functools import partial
from typing import Optional, Tuple
import gym
import numpy as np
from tianshou.env import MultiAgentEnv
class TicTacToeEnv(MultiAgentEnv):
"""This is a simple implementation of the Tic-Tac-Toe game, where two
agents play against each other.
The implementation is intended to show how to wrap an environment to
satisfy the interface of :class:`~tianshou.env.MultiAgentEnv`.
:param size: the size of the board (square board)
:param win_size: how many units in a row is considered to win
"""
def __init__(self, size: int = 3, win_size: int = 3):
super().__init__()
assert size > 0, f'board size should be positive, but got {size}'
self.size = size
assert win_size > 0, f'win-size should be positive, but got {win_size}'
self.win_size = win_size
assert win_size <= size, f'win-size {win_size} should not ' \
f'be larger than board size {size}'
self.convolve_kernel = np.ones(win_size)
self.observation_space = gym.spaces.Box(
low=-1.0, high=1.0, shape=(size, size), dtype=np.float32
)
self.action_space = gym.spaces.Discrete(size * size)
self.current_board = None
self.current_agent = None
self._last_move = None
self.step_num = None
def reset(self) -> dict:
self.current_board = np.zeros((self.size, self.size), dtype=np.int32)
self.current_agent = 1
self._last_move = (-1, -1)
self.step_num = 0
return {
'agent_id': self.current_agent,
'obs': np.array(self.current_board),
'mask': self.current_board.flatten() == 0
}
def step(self, action: [int,
np.ndarray]) -> Tuple[dict, np.ndarray, np.ndarray, dict]:
if self.current_agent is None:
raise ValueError("calling step() of unreset environment is prohibited!")
assert 0 <= action < self.size * self.size
assert self.current_board.item(action) == 0
_current_agent = self.current_agent
self._move(action)
mask = self.current_board.flatten() == 0
is_win, is_opponent_win = False, False
is_win = self._test_win()
# the game is over when one wins or there is only one empty place
done = is_win
if sum(mask) == 1:
done = True
self._move(np.where(mask)[0][0])
is_opponent_win = self._test_win()
if is_win:
reward = 1
elif is_opponent_win:
reward = -1
else:
reward = 0
obs = {
'agent_id': self.current_agent,
'obs': np.array(self.current_board),
'mask': mask
}
rew_agent_1 = reward if _current_agent == 1 else (-reward)
rew_agent_2 = reward if _current_agent == 2 else (-reward)
vec_rew = np.array([rew_agent_1, rew_agent_2], dtype=np.float32)
if done:
self.current_agent = None
return obs, vec_rew, np.array(done), {}
def _move(self, action):
row, col = action // self.size, action % self.size
if self.current_agent == 1:
self.current_board[row, col] = 1
else:
self.current_board[row, col] = -1
self.current_agent = 3 - self.current_agent
self._last_move = (row, col)
self.step_num += 1
def _test_win(self):
"""test if someone wins by checking the situation around last move"""
row, col = self._last_move
rboard = self.current_board[row, :]
cboard = self.current_board[:, col]
current = self.current_board[row, col]
rightup = [
self.current_board[row - i, col + i] for i in range(1, self.size - col)
if row - i >= 0
]
leftdown = [
self.current_board[row + i, col - i] for i in range(1, col + 1)
if row + i < self.size
]
rdiag = np.array(leftdown[::-1] + [current] + rightup)
rightdown = [
self.current_board[row + i, col + i] for i in range(1, self.size - col)
if row + i < self.size
]
leftup = [
self.current_board[row - i, col - i] for i in range(1, col + 1)
if row - i >= 0
]
diag = np.array(leftup[::-1] + [current] + rightdown)
results = [
np.convolve(k, self.convolve_kernel, mode='valid')
for k in (rboard, cboard, rdiag, diag)
]
return any([(np.abs(x) == self.win_size).any() for x in results])
def seed(self, seed: Optional[int] = None) -> int:
pass
def render(self, **kwargs) -> None:
print(f'board (step {self.step_num}):')
pad = '==='
top = pad + '=' * (2 * self.size - 1) + pad
print(top)
def f(i, data):
j, number = data
last_move = i == self._last_move[0] and j == self._last_move[1]
if number == 1:
return 'X' if last_move else 'x'
if number == -1:
return 'O' if last_move else 'o'
return '_'
for i, row in enumerate(self.current_board):
print(pad + ' '.join(map(partial(f, i), enumerate(row))) + pad)
print(top)
def close(self) -> None:
pass


@@ -0,0 +1,185 @@
import argparse
import os
from typing import List, Optional, Tuple
import gym
import numpy as np
import pettingzoo.butterfly.pistonball_v4 as pistonball_v4
import torch
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
def get_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=1626)
parser.add_argument('--eps-test', type=float, default=0.05)
parser.add_argument('--eps-train', type=float, default=0.1)
parser.add_argument('--buffer-size', type=int, default=2000)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument(
'--gamma', type=float, default=0.9, help='a smaller gamma favors earlier win'
)
parser.add_argument(
'--n-pistons',
type=int,
default=3,
        help='Number of pistons (agents) in the env'
)
parser.add_argument('--n-step', type=int, default=100)
parser.add_argument('--target-update-freq', type=int, default=320)
parser.add_argument('--epoch', type=int, default=3)
parser.add_argument('--step-per-epoch', type=int, default=500)
parser.add_argument('--step-per-collect', type=int, default=10)
parser.add_argument('--update-per-step', type=float, default=0.1)
parser.add_argument('--batch-size', type=int, default=100)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.0)
parser.add_argument(
'--watch',
default=False,
action='store_true',
help='no training, '
'watch the play of pre-trained models'
)
parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
)
return parser
def get_args() -> argparse.Namespace:
parser = get_parser()
return parser.parse_known_args()[0]
def get_env(args: argparse.Namespace = get_args()):
return PettingZooEnv(pistonball_v4.env(continuous=False, n_pistons=args.n_pistons))
def get_agents(
args: argparse.Namespace = get_args(),
agents: Optional[List[BasePolicy]] = None,
optims: Optional[List[torch.optim.Optimizer]] = None,
) -> Tuple[BasePolicy, List[torch.optim.Optimizer], List]:
env = get_env()
observation_space = env.observation_space['observation'] if isinstance(
env.observation_space, gym.spaces.Dict
) else env.observation_space
args.state_shape = observation_space.shape or observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
if agents is None:
agents = []
optims = []
for _ in range(args.n_pistons):
# model
net = Net(
args.state_shape,
args.action_shape,
hidden_sizes=args.hidden_sizes,
device=args.device
).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
agent = DQNPolicy(
net,
optim,
args.gamma,
args.n_step,
target_update_freq=args.target_update_freq
)
agents.append(agent)
optims.append(optim)
policy = MultiAgentPolicyManager(agents, env)
return policy, optims, env.agents
def train_agent(
args: argparse.Namespace = get_args(),
agents: Optional[List[BasePolicy]] = None,
optims: Optional[List[torch.optim.Optimizer]] = None,
) -> Tuple[dict, BasePolicy]:
train_envs = DummyVectorEnv([get_env for _ in range(args.training_num)])
test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
policy, optim, agents = get_agents(args, agents=agents, optims=optims)
# collector
train_collector = Collector(
policy,
train_envs,
VectorReplayBuffer(args.buffer_size, len(train_envs)),
exploration_noise=True
)
test_collector = Collector(policy, test_envs, exploration_noise=True)
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
log_path = os.path.join(args.logdir, 'pistonball', 'dqn')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = TensorboardLogger(writer)
def save_fn(policy):
pass
def stop_fn(mean_rewards):
return False
def train_fn(epoch, env_step):
[agent.set_eps(args.eps_train) for agent in policy.policies.values()]
def test_fn(epoch, env_step):
[agent.set_eps(args.eps_test) for agent in policy.policies.values()]
def reward_metric(rews):
return rews[:, 0]
# trainer
result = offpolicy_trainer(
policy,
train_collector,
test_collector,
args.epoch,
args.step_per_epoch,
args.step_per_collect,
args.test_num,
args.batch_size,
train_fn=train_fn,
test_fn=test_fn,
stop_fn=stop_fn,
save_fn=save_fn,
update_per_step=args.update_per_step,
logger=logger,
test_in_train=False,
reward_metric=reward_metric
)
return result, policy
def watch(
args: argparse.Namespace = get_args(), policy: Optional[BasePolicy] = None
) -> None:
env = get_env()
policy.eval()
[agent.set_eps(args.eps_test) for agent in policy.policies.values()]
collector = Collector(policy, env, exploration_noise=True)
result = collector.collect(n_episode=1, render=args.render)
rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews[:, 0].mean()}, length: {lens.mean()}")


@@ -0,0 +1,276 @@
import argparse
import os
from typing import Any, Dict, List, Optional, Tuple, Union
import gym
import numpy as np
import pettingzoo.butterfly.pistonball_v4 as pistonball_v4
import torch
import torch.nn as nn
from torch.distributions import Independent, Normal
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.policy import BasePolicy, MultiAgentPolicyManager, PPOPolicy
from tianshou.trainer import onpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.continuous import ActorProb, Critic
class DQN(nn.Module):
"""Reference: Human-level control through deep reinforcement learning.
For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
"""
def __init__(
self,
c: int,
h: int,
w: int,
device: Union[str, int, torch.device] = "cpu",
) -> None:
super().__init__()
self.device = device
self.c = c
self.h = h
self.w = w
self.net = nn.Sequential(
nn.Conv2d(c, 32, kernel_size=8, stride=4), nn.ReLU(inplace=True),
nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(inplace=True),
nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(inplace=True),
nn.Flatten()
)
with torch.no_grad():
self.output_dim = np.prod(self.net(torch.zeros(1, c, h, w)).shape[1:])
def forward(
self,
x: Union[np.ndarray, torch.Tensor],
state: Optional[Any] = None,
info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]:
r"""Mapping: x -> Q(x, \*)."""
x = torch.as_tensor(x, device=self.device, dtype=torch.float32)
return self.net(x.reshape(-1, self.c, self.w, self.h)), state
def get_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=1626)
parser.add_argument('--eps-test', type=float, default=0.05)
parser.add_argument('--eps-train', type=float, default=0.1)
parser.add_argument('--buffer-size', type=int, default=2000)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument(
'--gamma', type=float, default=0.9, help='a smaller gamma favors earlier win'
)
parser.add_argument(
'--n-pistons',
type=int,
default=3,
        help='Number of pistons (agents) in the env'
)
parser.add_argument('--n-step', type=int, default=100)
parser.add_argument('--target-update-freq', type=int, default=320)
parser.add_argument('--epoch', type=int, default=5)
parser.add_argument('--step-per-epoch', type=int, default=500)
parser.add_argument('--step-per-collect', type=int, default=10)
parser.add_argument('--episode-per-collect', type=int, default=16)
parser.add_argument('--repeat-per-collect', type=int, default=2)
parser.add_argument('--update-per-step', type=float, default=0.1)
parser.add_argument('--batch-size', type=int, default=1000)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
parser.add_argument('--training-num', type=int, default=1000)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument(
'--watch',
default=False,
action='store_true',
help='no training, '
'watch the play of pre-trained models'
)
parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
)
# ppo special
parser.add_argument('--vf-coef', type=float, default=0.25)
parser.add_argument('--ent-coef', type=float, default=0.0)
parser.add_argument('--eps-clip', type=float, default=0.2)
parser.add_argument('--max-grad-norm', type=float, default=0.5)
parser.add_argument('--gae-lambda', type=float, default=0.95)
parser.add_argument('--rew-norm', type=int, default=1)
parser.add_argument('--dual-clip', type=float, default=None)
parser.add_argument('--value-clip', type=int, default=1)
parser.add_argument('--norm-adv', type=int, default=1)
parser.add_argument('--recompute-adv', type=int, default=0)
parser.add_argument('--resume', action="store_true")
parser.add_argument("--save-interval", type=int, default=4)
parser.add_argument('--render', type=float, default=0.0)
return parser
def get_args() -> argparse.Namespace:
parser = get_parser()
return parser.parse_known_args()[0]
def get_env(args: argparse.Namespace = get_args()):
return PettingZooEnv(pistonball_v4.env(continuous=True, n_pistons=args.n_pistons))
def get_agents(
args: argparse.Namespace = get_args(),
agents: Optional[List[BasePolicy]] = None,
optims: Optional[List[torch.optim.Optimizer]] = None,
) -> Tuple[BasePolicy, List[torch.optim.Optimizer], List]:
env = get_env()
observation_space = env.observation_space['observation'] if isinstance(
env.observation_space, gym.spaces.Dict
) else env.observation_space
args.state_shape = observation_space.shape or observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
args.max_action = env.action_space.high[0]
if agents is None:
agents = []
optims = []
for _ in range(args.n_pistons):
# model
net = DQN(
observation_space.shape[2],
observation_space.shape[1],
observation_space.shape[0],
device=args.device
).to(args.device)
actor = ActorProb(
net, args.action_shape, max_action=args.max_action, device=args.device
).to(args.device)
net2 = DQN(
observation_space.shape[2],
observation_space.shape[1],
observation_space.shape[0],
device=args.device
).to(args.device)
critic = Critic(net2, device=args.device).to(args.device)
for m in set(actor.modules()).union(critic.modules()):
if isinstance(m, torch.nn.Linear):
torch.nn.init.orthogonal_(m.weight)
torch.nn.init.zeros_(m.bias)
optim = torch.optim.Adam(
set(actor.parameters()).union(critic.parameters()), lr=args.lr
)
def dist(*logits):
return Independent(Normal(*logits), 1)
agent = PPOPolicy(
actor,
critic,
optim,
dist,
discount_factor=args.gamma,
max_grad_norm=args.max_grad_norm,
eps_clip=args.eps_clip,
vf_coef=args.vf_coef,
ent_coef=args.ent_coef,
reward_normalization=args.rew_norm,
advantage_normalization=args.norm_adv,
recompute_advantage=args.recompute_adv,
# dual_clip=args.dual_clip,
# dual clip cause monotonically increasing log_std :)
value_clip=args.value_clip,
gae_lambda=args.gae_lambda,
action_space=env.action_space
)
agents.append(agent)
optims.append(optim)
policy = MultiAgentPolicyManager(
agents, env, action_scaling=True, action_bound_method='clip'
)
return policy, optims, env.agents
def train_agent(
args: argparse.Namespace = get_args(),
agents: Optional[List[BasePolicy]] = None,
optims: Optional[List[torch.optim.Optimizer]] = None,
) -> Tuple[dict, BasePolicy]:
train_envs = DummyVectorEnv([get_env for _ in range(args.training_num)])
test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
policy, optim, agents = get_agents(args, agents=agents, optims=optims)
# collector
train_collector = Collector(
policy,
train_envs,
VectorReplayBuffer(args.buffer_size, len(train_envs)),
exploration_noise=False # True
)
test_collector = Collector(policy, test_envs)
# train_collector.collect(n_step=args.batch_size * args.training_num)
# log
log_path = os.path.join(args.logdir, 'pistonball', 'dqn')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = TensorboardLogger(writer)
def save_fn(policy):
pass
def stop_fn(mean_rewards):
return False
def train_fn(epoch, env_step):
[agent.set_eps(args.eps_train) for agent in policy.policies.values()]
def test_fn(epoch, env_step):
[agent.set_eps(args.eps_test) for agent in policy.policies.values()]
def reward_metric(rews):
return rews[:, 0]
# trainer
result = onpolicy_trainer(
policy,
train_collector,
test_collector,
args.epoch,
args.step_per_epoch,
args.repeat_per_collect,
args.test_num,
args.batch_size,
episode_per_collect=args.episode_per_collect,
stop_fn=stop_fn,
save_fn=save_fn,
logger=logger,
resume_from_log=args.resume
)
return result, policy
def watch(
args: argparse.Namespace = get_args(), policy: Optional[BasePolicy] = None
) -> None:
env = get_env()
policy.eval()
collector = Collector(policy, env)
result = collector.collect(n_episode=1, render=args.render)
rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews[:, 0].mean()}, length: {lens.mean()}")


@@ -0,0 +1,21 @@
import pprint
from pistonball import get_args, train_agent, watch
def test_piston_ball(args=get_args()):
if args.watch:
watch(args)
return
result, agent = train_agent(args)
# assert result["best_reward"] >= args.win_rate
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
watch(args, agent)
if __name__ == '__main__':
test_piston_ball(get_args())


@@ -0,0 +1,21 @@
import pprint
from pistonball_continuous import get_args, train_agent, watch
def test_piston_ball_continuous(args=get_args()):
if args.watch:
watch(args)
return
result, agent = train_agent(args)
assert result["best_reward"] >= 30.0
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
watch(args, agent)
if __name__ == '__main__':
test_piston_ball_continuous(get_args())


@@ -1,21 +1,21 @@
 import pprint

 from tic_tac_toe import get_args, train_agent, watch


 def test_tic_tac_toe(args=get_args()):
     if args.watch:
         watch(args)
         return
     result, agent = train_agent(args)
     assert result["best_reward"] >= args.win_rate

     if __name__ == '__main__':
         pprint.pprint(result)
         # Let's watch its performance!
         watch(args, agent)


 if __name__ == '__main__':
     test_tic_tac_toe(get_args())


@@ -1,232 +1,241 @@
 import argparse
 import os
 from copy import deepcopy
 from typing import Optional, Tuple

+import gym
 import numpy as np
 import torch
-from tic_tac_toe_env import TicTacToeEnv
+from pettingzoo.classic import tictactoe_v3
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.env import DummyVectorEnv
+from tianshou.env.pettingzoo_env import PettingZooEnv
 from tianshou.policy import (
     BasePolicy,
     DQNPolicy,
     MultiAgentPolicyManager,
     RandomPolicy,
 )
 from tianshou.trainer import offpolicy_trainer
 from tianshou.utils import TensorboardLogger
 from tianshou.utils.net.common import Net


+def get_env():
+    return PettingZooEnv(tictactoe_v3.env())
+
+
 def get_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser()
     parser.add_argument('--seed', type=int, default=1626)
     parser.add_argument('--eps-test', type=float, default=0.05)
     parser.add_argument('--eps-train', type=float, default=0.1)
     parser.add_argument('--buffer-size', type=int, default=20000)
-    parser.add_argument('--lr', type=float, default=1e-3)
+    parser.add_argument('--lr', type=float, default=1e-4)
     parser.add_argument(
         '--gamma', type=float, default=0.9, help='a smaller gamma favors earlier win'
     )
     parser.add_argument('--n-step', type=int, default=3)
     parser.add_argument('--target-update-freq', type=int, default=320)
-    parser.add_argument('--epoch', type=int, default=20)
-    parser.add_argument('--step-per-epoch', type=int, default=5000)
+    parser.add_argument('--epoch', type=int, default=50)
+    parser.add_argument('--step-per-epoch', type=int, default=1000)
     parser.add_argument('--step-per-collect', type=int, default=10)
     parser.add_argument('--update-per-step', type=float, default=0.1)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument(
         '--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]
     )
     parser.add_argument('--training-num', type=int, default=10)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.1)
-    parser.add_argument('--board-size', type=int, default=6)
-    parser.add_argument('--win-size', type=int, default=4)
     parser.add_argument(
-        '--win-rate', type=float, default=0.9, help='the expected winning rate'
+        '--win-rate',
+        type=float,
+        default=0.6,
+        help='the expected winning rate: Optimal policy can get 0.7'
     )
     parser.add_argument(
         '--watch',
         default=False,
         action='store_true',
         help='no training, '
         'watch the play of pre-trained models'
     )
     parser.add_argument(
         '--agent-id',
         type=int,
         default=2,
         help='the learned agent plays as the'
         ' agent_id-th player. Choices are 1 and 2.'
     )
     parser.add_argument(
         '--resume-path',
         type=str,
         default='',
         help='the path of agent pth file '
         'for resuming from a pre-trained agent'
     )
     parser.add_argument(
         '--opponent-path',
         type=str,
         default='',
         help='the path of opponent agent pth file '
         'for resuming from a pre-trained agent'
     )
     parser.add_argument(
         '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
     )
     return parser


 def get_args() -> argparse.Namespace:
     parser = get_parser()
     return parser.parse_known_args()[0]


 def get_agents(
     args: argparse.Namespace = get_args(),
     agent_learn: Optional[BasePolicy] = None,
     agent_opponent: Optional[BasePolicy] = None,
     optim: Optional[torch.optim.Optimizer] = None,
-) -> Tuple[BasePolicy, torch.optim.Optimizer]:
-    env = TicTacToeEnv(args.board_size, args.win_size)
-    args.state_shape = env.observation_space.shape or env.observation_space.n
+) -> Tuple[BasePolicy, torch.optim.Optimizer, list]:
+    env = get_env()
+    observation_space = env.observation_space['observation'] if isinstance(
+        env.observation_space, gym.spaces.Dict
+    ) else env.observation_space
+    args.state_shape = observation_space.shape or observation_space.n
     args.action_shape = env.action_space.shape or env.action_space.n
     if agent_learn is None:
         # model
         net = Net(
             args.state_shape,
             args.action_shape,
             hidden_sizes=args.hidden_sizes,
             device=args.device
         ).to(args.device)
         if optim is None:
             optim = torch.optim.Adam(net.parameters(), lr=args.lr)
         agent_learn = DQNPolicy(
             net,
             optim,
             args.gamma,
             args.n_step,
             target_update_freq=args.target_update_freq
         )
         if args.resume_path:
             agent_learn.load_state_dict(torch.load(args.resume_path))

     if agent_opponent is None:
         if args.opponent_path:
             agent_opponent = deepcopy(agent_learn)
             agent_opponent.load_state_dict(torch.load(args.opponent_path))
         else:
             agent_opponent = RandomPolicy()

     if args.agent_id == 1:
         agents = [agent_learn, agent_opponent]
     else:
         agents = [agent_opponent, agent_learn]
-    policy = MultiAgentPolicyManager(agents)
-    return policy, optim
+    policy = MultiAgentPolicyManager(agents, env)
+    return policy, optim, env.agents


 def train_agent(
     args: argparse.Namespace = get_args(),
     agent_learn: Optional[BasePolicy] = None,
     agent_opponent: Optional[BasePolicy] = None,
     optim: Optional[torch.optim.Optimizer] = None,
 ) -> Tuple[dict, BasePolicy]:

-    def env_func():
-        return TicTacToeEnv(args.board_size, args.win_size)

-    train_envs = DummyVectorEnv([env_func for _ in range(args.training_num)])
-    test_envs = DummyVectorEnv([env_func for _ in range(args.test_num)])
+    train_envs = DummyVectorEnv([get_env for _ in range(args.training_num)])
+    test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)])
     # seed
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)

-    policy, optim = get_agents(
+    policy, optim, agents = get_agents(
         args, agent_learn=agent_learn, agent_opponent=agent_opponent, optim=optim
     )

     # collector
     train_collector = Collector(
         policy,
         train_envs,
         VectorReplayBuffer(args.buffer_size, len(train_envs)),
         exploration_noise=True
     )
     test_collector = Collector(policy, test_envs, exploration_noise=True)
     # policy.set_eps(1)
     train_collector.collect(n_step=args.batch_size * args.training_num)
     # log
     log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn')
     writer = SummaryWriter(log_path)
     writer.add_text("args", str(args))
     logger = TensorboardLogger(writer)

     def save_fn(policy):
         if hasattr(args, 'model_save_path'):
             model_save_path = args.model_save_path
         else:
             model_save_path = os.path.join(
                 args.logdir, 'tic_tac_toe', 'dqn', 'policy.pth'
             )
-        torch.save(policy.policies[args.agent_id - 1].state_dict(), model_save_path)
+        torch.save(
+            policy.policies[agents[args.agent_id - 1]].state_dict(), model_save_path
+        )

     def stop_fn(mean_rewards):
         return mean_rewards >= args.win_rate

     def train_fn(epoch, env_step):
-        policy.policies[args.agent_id - 1].set_eps(args.eps_train)
+        policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_train)

     def test_fn(epoch, env_step):
-        policy.policies[args.agent_id - 1].set_eps(args.eps_test)
+        policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test)

     def reward_metric(rews):
         return rews[:, args.agent_id - 1]

     # trainer
     result = offpolicy_trainer(
         policy,
         train_collector,
         test_collector,
         args.epoch,
         args.step_per_epoch,
         args.step_per_collect,
         args.test_num,
         args.batch_size,
         train_fn=train_fn,
         test_fn=test_fn,
         stop_fn=stop_fn,
         save_fn=save_fn,
         update_per_step=args.update_per_step,
         logger=logger,
         test_in_train=False,
         reward_metric=reward_metric
     )

-    return result, policy.policies[args.agent_id - 1]
+    return result, policy.policies[agents[args.agent_id - 1]]


 def watch(
     args: argparse.Namespace = get_args(),
     agent_learn: Optional[BasePolicy] = None,
     agent_opponent: Optional[BasePolicy] = None,
 ) -> None:
-    env = TicTacToeEnv(args.board_size, args.win_size)
-    policy, optim = get_agents(
+    env = get_env()
+    policy, optim, agents = get_agents(
         args, agent_learn=agent_learn, agent_opponent=agent_opponent
     )
     policy.eval()
-    policy.policies[args.agent_id - 1].set_eps(args.eps_test)
+    policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test)
     collector = Collector(policy, env, exploration_noise=True)
     result = collector.collect(n_episode=1, render=args.render)
     rews, lens = result["rews"], result["lens"]
     print(f"Final reward: {rews[:, args.agent_id - 1].mean()}, length: {lens.mean()}")


@@ -1,6 +1,6 @@
 """Env package."""
-from tianshou.env.maenv import MultiAgentEnv
+from tianshou.env.pettingzoo_env import PettingZooEnv
 from tianshou.env.venvs import (
     BaseVectorEnv,
     DummyVectorEnv,
@@ -15,5 +15,5 @@ __all__ = [
     "SubprocVectorEnv",
     "ShmemVectorEnv",
     "RayVectorEnv",
-    "MultiAgentEnv",
+    "PettingZooEnv",
 ]

tianshou/env/maenv.py

@@ -1,65 +0,0 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Tuple
import gym
import numpy as np
class MultiAgentEnv(ABC, gym.Env):
"""The interface for multi-agent environments.
Multi-agent environments must be wrapped as
:class:`~tianshou.env.MultiAgentEnv`. Here is the usage:
::
env = MultiAgentEnv(...)
# obs is a dict containing obs, agent_id, and mask
obs = env.reset()
act = policy(obs)
obs, rew, done, info = env.step(act)
env.close()
The available action's mask is set to 1, otherwise it is set to 0. Further
usage can be found at :ref:`marl_example`.
"""
def __init__(self) -> None:
pass
@abstractmethod
def reset(self) -> dict:
"""Reset the state.
Return the initial state, first agent_id, and the initial action set,
for example, ``{'obs': obs, 'agent_id': agent_id, 'mask': mask}``.
"""
pass
@abstractmethod
def step(
self, action: np.ndarray
) -> Tuple[Dict[str, Any], np.ndarray, np.ndarray, np.ndarray]:
"""Run one timestep of the environments dynamics.
When the end of episode is reached, you are responsible for calling
reset() to reset the environments state.
Accept action and return a tuple (obs, rew, done, info).
:param numpy.ndarray action: action provided by a agent.
:return: A tuple including four items:
* ``obs`` a dict containing obs, agent_id, and mask, which means \
that it is the ``agent_id`` player's turn to play with ``obs``\
observation and ``mask``.
* ``rew`` a numpy.ndarray, the amount of rewards returned after \
previous actions. Depending on the specific environment, this \
can be either a scalar reward for current agent or a vector \
reward for all the agents.
* ``done`` a numpy.ndarray, whether the episode has ended, in \
which case further step() calls will return undefined results
* ``info`` a numpy.ndarray, contains auxiliary diagnostic \
information (helpful for debugging, and sometimes learning)
"""
pass

tianshou/env/pettingzoo_env.py

@@ -0,0 +1,112 @@
from abc import ABC
from typing import Any, Dict, List, Tuple
import gym.spaces
from pettingzoo.utils.env import AECEnv
from pettingzoo.utils.wrappers import BaseWrapper
class PettingZooEnv(AECEnv, gym.Env, ABC):
"""The interface for petting zoo environments.
Multi-agent environments must be wrapped as
:class:`~tianshou.env.PettingZooEnv`. Here is the usage:
::
env = PettingZooEnv(...)
# obs is a dict containing obs, agent_id, and mask
obs = env.reset()
action = policy(obs)
obs, rew, done, info = env.step(action)
env.close()
    In the returned action mask, available actions are marked True and unavailable ones False.
Further usage can be found at :ref:`marl_example`.
"""
def __init__(self, env: BaseWrapper):
super().__init__()
self.env = env
# agent idx list
self.agents = self.env.possible_agents
self.agent_idx = {}
for i, agent_id in enumerate(self.agents):
self.agent_idx[agent_id] = i
# Get dictionaries of obs_spaces and act_spaces
self.observation_spaces = self.env.observation_spaces
self.action_spaces = self.env.action_spaces
self.rewards = [0] * len(self.agents)
# Get first observation space, assuming all agents have equal space
self.observation_space: Any = self.observation_space(self.agents[0])
# Get first action space, assuming all agents have equal space
self.action_space: Any = self.action_space(self.agents[0])
assert all(self.env.observation_space(agent) == self.observation_space
for agent in self.agents), \
"Observation spaces for all agents must be identical. Perhaps " \
"SuperSuit's pad_observations wrapper can help (useage: " \
"`supersuit.aec_wrappers.pad_observations(env)`"
assert all(self.env.action_space(agent) == self.action_space
for agent in self.agents), \
"Action spaces for all agents must be identical. Perhaps " \
"SuperSuit's pad_action_space wrapper can help (useage: " \
"`supersuit.aec_wrappers.pad_action_space(env)`"
self.reset()
def reset(self) -> dict:
self.env.reset()
observation = self.env.observe(self.env.agent_selection)
if isinstance(observation, dict) and 'action_mask' in observation:
return {
'agent_id': self.env.agent_selection,
'obs': observation['observation'],
'mask':
[True if obm == 1 else False for obm in observation['action_mask']]
}
else:
if isinstance(self.action_space, gym.spaces.Discrete):
return {
'agent_id': self.env.agent_selection,
'obs': observation,
'mask': [True] * self.env.action_space(self.env.agent_selection).n
}
else:
return {'agent_id': self.env.agent_selection, 'obs': observation}
def step(self, action: Any) -> Tuple[Dict, List[int], bool, Dict]:
self.env.step(action)
observation, rew, done, info = self.env.last()
if isinstance(observation, dict) and 'action_mask' in observation:
obs = {
'agent_id': self.env.agent_selection,
'obs': observation['observation'],
'mask':
[True if obm == 1 else False for obm in observation['action_mask']]
}
else:
if isinstance(self.action_space, gym.spaces.Discrete):
obs = {
'agent_id': self.env.agent_selection,
'obs': observation,
'mask': [True] * self.env.action_space(self.env.agent_selection).n
}
else:
obs = {'agent_id': self.env.agent_selection, 'obs': observation}
for agent_id, reward in self.env.rewards.items():
self.rewards[self.agent_idx[agent_id]] = reward
return obs, self.rewards, done, info
def close(self) -> None:
self.env.close()
def seed(self, seed: Any = None) -> None:
self.env.seed(seed)
def render(self, mode: str = "human") -> Any:
return self.env.render(mode)
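
For orientation, a minimal sketch of driving the wrapper directly on a PettingZoo classic game; the agent names and shapes in the comments are illustrative rather than guaranteed:

from pettingzoo.classic import tictactoe_v3
from tianshou.env import DummyVectorEnv
from tianshou.env.pettingzoo_env import PettingZooEnv

env = PettingZooEnv(tictactoe_v3.env())
obs = env.reset()
# dict observation: e.g. agent_id='player_1', obs shape (3, 3, 2), boolean mask of legal moves
print(obs['agent_id'], obs['obs'].shape, obs['mask'])

action = obs['mask'].index(True)          # first legal action for the current player
obs, rew, done, info = env.step(action)
print(rew)                                # one reward entry per agent, ordered as env.agents

# vectorized training environments are built from zero-argument factories, as in the examples above
train_envs = DummyVectorEnv([lambda: PettingZooEnv(tictactoe_v3.env()) for _ in range(4)])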


@@ -2,6 +2,7 @@ from typing import Any, Callable, List, Optional, Tuple, Union

 import gym
 import numpy as np
+import pettingzoo

 from tianshou.env.worker import (
     DummyEnvWorker,
@@ -364,7 +365,10 @@ class DummyVectorEnv(BaseVectorEnv):
     Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
     """

-    def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
+    def __init__(
+        self, env_fns: List[Callable[[], Union[gym.Env, pettingzoo.AECEnv]]],
+        **kwargs: Any
+    ) -> None:
         super().__init__(env_fns, DummyEnvWorker, **kwargs)


@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import numpy as np

 from tianshou.data import Batch, ReplayBuffer
+from tianshou.env.pettingzoo_env import PettingZooEnv
 from tianshou.policy import BasePolicy

@@ -16,21 +17,29 @@ class MultiAgentPolicyManager(BasePolicy):
     :ref:`marl_example` can help you better understand this procedure.
     """

-    def __init__(self, policies: List[BasePolicy], **kwargs: Any) -> None:
-        super().__init__(**kwargs)
-        self.policies = policies
+    def __init__(
+        self, policies: List[BasePolicy], env: PettingZooEnv, **kwargs: Any
+    ) -> None:
+        super().__init__(action_space=env.action_space, **kwargs)
+        assert (
+            len(policies) == len(env.agents)
+        ), "One policy must be assigned for each agent."
+
+        self.agent_idx = env.agent_idx
         for i, policy in enumerate(policies):
             # agent_id 0 is reserved for the environment proxy
             # (this MultiAgentPolicyManager)
-            policy.set_agent_id(i + 1)
+            policy.set_agent_id(env.agents[i])
+
+        self.policies = dict(zip(env.agents, policies))

     def replace_policy(self, policy: BasePolicy, agent_id: int) -> None:
         """Replace the "agent_id"th policy in this manager."""
-        self.policies[agent_id - 1] = policy
         policy.set_agent_id(agent_id)
+        self.policies[agent_id] = policy

     def process_fn(
-        self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
+        self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
     ) -> Batch:
         """Dispatch batch data from obs.agent_id to every policy's process_fn.

@@ -45,18 +54,21 @@ class MultiAgentPolicyManager(BasePolicy):
         # Since we do not override buffer.__setattr__, here we use _meta to
         # change buffer.rew, otherwise buffer.rew = Batch() has no effect.
         save_rew, buffer._meta.rew = buffer.rew, Batch()
-        for policy in self.policies:
-            agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0]
+        for agent, policy in self.policies.items():
+            agent_index = np.nonzero(batch.obs.agent_id == agent)[0]
             if len(agent_index) == 0:
-                results[f"agent_{policy.agent_id}"] = Batch()
+                results[agent] = Batch()
                 continue
-            tmp_batch, tmp_indices = batch[agent_index], indices[agent_index]
+            tmp_batch, tmp_indice = batch[agent_index], indice[agent_index]
             if has_rew:
-                tmp_batch.rew = tmp_batch.rew[:, policy.agent_id - 1]
-                buffer._meta.rew = save_rew[:, policy.agent_id - 1]
-            results[f"agent_{policy.agent_id}"] = policy.process_fn(
-                tmp_batch, buffer, tmp_indices
-            )
+                tmp_batch.rew = tmp_batch.rew[:, self.agent_idx[agent]]
+                buffer._meta.rew = save_rew[:, self.agent_idx[agent]]
+            if not hasattr(tmp_batch.obs, "mask"):
+                if hasattr(tmp_batch.obs, 'obs'):
+                    tmp_batch.obs = tmp_batch.obs.obs
+                if hasattr(tmp_batch.obs_next, 'obs'):
+                    tmp_batch.obs_next = tmp_batch.obs_next.obs
+            results[agent] = policy.process_fn(tmp_batch, buffer, tmp_indice)
         if has_rew:  # restore from save_rew
             buffer._meta.rew = save_rew
         return Batch(results)

@@ -64,8 +76,8 @@ class MultiAgentPolicyManager(BasePolicy):
     def exploration_noise(self, act: Union[np.ndarray, Batch],
                           batch: Batch) -> Union[np.ndarray, Batch]:
         """Add exploration noise from sub-policy onto act."""
-        for policy in self.policies:
-            agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0]
+        for agent_id, policy in self.policies.items():
+            agent_index = np.nonzero(batch.obs.agent_id == agent_id)[0]
             if len(agent_index) == 0:
                 continue
             act[agent_index] = policy.exploration_noise(

@@ -104,7 +116,7 @@ class MultiAgentPolicyManager(BasePolicy):
         """
         results: List[Tuple[bool, np.ndarray, Batch, Union[np.ndarray, Batch],
                             Batch]] = []
-        for policy in self.policies:
+        for agent_id, policy in self.policies.items():
             # This part of code is difficult to understand.
             # Let's follow an example with two agents
             # batch.obs.agent_id is [1, 2, 1, 2, 1, 2] (with batch_size == 6)

@@ -112,7 +124,7 @@ class MultiAgentPolicyManager(BasePolicy):
             # agent_index for agent 1 is [0, 2, 4]
             # agent_index for agent 2 is [1, 3, 5]
             # we separate the transition of each agent according to agent_id
-            agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0]
+            agent_index = np.nonzero(batch.obs.agent_id == agent_id)[0]
             if len(agent_index) == 0:
                 # (has_data, agent_index, out, act, state)
                 results.append((False, np.array([-1]), Batch(), Batch(), Batch()))

@@ -120,11 +132,15 @@ class MultiAgentPolicyManager(BasePolicy):
             tmp_batch = batch[agent_index]
             if isinstance(tmp_batch.rew, np.ndarray):
                 # reward can be empty Batch (after initial reset) or nparray.
-                tmp_batch.rew = tmp_batch.rew[:, policy.agent_id - 1]
+                tmp_batch.rew = tmp_batch.rew[:, self.agent_idx[agent_id]]
+            if not hasattr(tmp_batch.obs, "mask"):
+                if hasattr(tmp_batch.obs, 'obs'):
+                    tmp_batch.obs = tmp_batch.obs.obs
+                if hasattr(tmp_batch.obs_next, 'obs'):
+                    tmp_batch.obs_next = tmp_batch.obs_next.obs
             out = policy(
                 batch=tmp_batch,
-                state=None if state is None else state["agent_" +
-                                                       str(policy.agent_id)],
+                state=None if state is None else state[agent_id],
                 **kwargs
             )
             act = out.act

@@ -141,12 +157,12 @@ class MultiAgentPolicyManager(BasePolicy):
             ]
         )
         state_dict, out_dict = {}, {}
-        for policy, (has_data, agent_index, out, act,
-                     state) in zip(self.policies, results):
+        for (agent_id, _), (has_data, agent_index, out, act,
+                            state) in zip(self.policies.items(), results):
             if has_data:
                 holder.act[agent_index] = act
-            state_dict["agent_" + str(policy.agent_id)] = state
-            out_dict["agent_" + str(policy.agent_id)] = out
+            state_dict[agent_id] = state
+            out_dict[agent_id] = out
         holder["out"] = out_dict
         holder["state"] = state_dict
         return holder

@@ -168,10 +184,10 @@ class MultiAgentPolicyManager(BasePolicy):
         }
         """
         results = {}
-        for policy in self.policies:
-            data = batch[f"agent_{policy.agent_id}"]
+        for agent_id, policy in self.policies.items():
+            data = batch[agent_id]
             if not data.is_empty():
                 out = policy.learn(batch=data, **kwargs)
                 for k, v in out.items():
-                    results["agent_" + str(policy.agent_id) + "/" + k] = v
+                    results[agent_id + "/" + k] = v
         return results