Pettingzoo support (#494)

Co-authored-by: Rodrigo de Lazcano <r.l.p.v96@gmail.com>
Co-authored-by: J K Terry <justinkterry@gmail.com>
Mohammad Mahdi Rahimi 2022-02-15 17:56:45 +03:00 committed by GitHub
parent d85bc19269
commit c7e2e56fac
14 changed files with 933 additions and 585 deletions


@@ -55,6 +55,7 @@ setup(
"torch>=1.4.0",
"numba>=0.51.0",
"h5py>=2.10.0", # to match tensorflow's minimal requirements
"pettingzoo>=1.12,<=1.13",
],
extras_require={
"dev": [
@@ -74,6 +75,9 @@ setup(
"pydocstyle",
"doc8",
"scipy",
"pillow",
"pygame>=2.1.0", # pettingzoo test cases pistonball
"pymunk>=6.2.1", # pettingzoo test cases pistonball
],
"atari": ["atari_py", "opencv-python"],
"mujoco": ["mujoco_py"],


@@ -1,87 +0,0 @@
import os
import pprint
from copy import deepcopy
import numpy as np
from tic_tac_toe import get_agents, get_parser, train_agent, watch
from tic_tac_toe_env import TicTacToeEnv
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector
from tianshou.env import DummyVectorEnv
from tianshou.policy import RandomPolicy
from tianshou.utils import TensorboardLogger
def get_args():
parser = get_parser()
parser.add_argument('--self_play_round', type=int, default=20)
args = parser.parse_known_args()[0]
return args
def gomoku(args=get_args()):
Collector._default_rew_metric = lambda x: x[args.agent_id - 1]
if args.watch:
watch(args)
return
policy, optim = get_agents(args)
agent_learn = policy.policies[args.agent_id - 1]
agent_opponent = policy.policies[2 - args.agent_id]
# log
log_path = os.path.join(args.logdir, 'Gomoku', 'dqn')
writer = SummaryWriter(log_path)
args.logger = TensorboardLogger(writer)
opponent_pool = [agent_opponent]
def env_func():
return TicTacToeEnv(args.board_size, args.win_size)
test_envs = DummyVectorEnv([env_func for _ in range(args.test_num)])
for round in range(args.self_play_round):
rews = []
agent_learn.set_eps(0.0)
# compute the reward over previous learner
for opponent in opponent_pool:
policy.replace_policy(opponent, 3 - args.agent_id)
test_collector = Collector(policy, test_envs)
results = test_collector.collect(n_episode=100)
rews.append(results['rews'].mean())
rews = np.array(rews)
# weight opponent by their difficulty level
rews = np.exp(-rews * 10.0)
rews /= np.sum(rews)
total_epoch = args.epoch
args.epoch = 1
for epoch in range(total_epoch):
# sample one opponent
opp_id = np.random.choice(len(opponent_pool), size=1, p=rews)
print(f'selection probability {rews.tolist()}')
print(f'selected opponent {opp_id}')
opponent = opponent_pool[opp_id.item(0)]
agent = RandomPolicy()
# previous learner can only be used for forward
agent.forward = opponent.forward
args.model_save_path = os.path.join(
args.logdir, 'Gomoku', 'dqn', f'policy_round_{round}_epoch_{epoch}.pth'
)
result, agent_learn = train_agent(
args, agent_learn=agent_learn, agent_opponent=agent, optim=optim
)
print(f'round_{round}_epoch_{epoch}')
pprint.pprint(result)
learnt_agent = deepcopy(agent_learn)
learnt_agent.set_eps(0.0)
opponent_pool.append(learnt_agent)
args.epoch = total_epoch
if __name__ == '__main__':
# Let's watch its performance!
opponent = opponent_pool[-2]
watch(args, agent_learn, opponent)
if __name__ == '__main__':
gomoku(get_args())
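The self-play loop above samples an opponent with probability proportional to exp(-reward * 10), so opponents the learner struggles against are picked more often. A quick numeric sketch of that weighting (the reward values are made up):

import numpy as np

# mean reward of the current learner against each stored opponent (illustrative numbers)
rews = np.array([0.9, 0.1, -0.5])
weights = np.exp(-rews * 10.0)   # low reward -> hard opponent -> large weight
weights /= weights.sum()
print(weights.round(4))          # the hardest opponent (reward -0.5) gets almost all of the mass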


@@ -1,148 +0,0 @@
from functools import partial
from typing import Optional, Tuple, Union
import gym
import numpy as np
from tianshou.env import MultiAgentEnv
class TicTacToeEnv(MultiAgentEnv):
"""This is a simple implementation of the Tic-Tac-Toe game, where two
agents play against each other.
The implementation is intended to show how to wrap an environment to
satisfy the interface of :class:`~tianshou.env.MultiAgentEnv`.
:param size: the size of the board (square board)
:param win_size: how many units in a row is considered to win
"""
def __init__(self, size: int = 3, win_size: int = 3):
super().__init__()
assert size > 0, f'board size should be positive, but got {size}'
self.size = size
assert win_size > 0, f'win-size should be positive, but got {win_size}'
self.win_size = win_size
assert win_size <= size, f'win-size {win_size} should not ' \
f'be larger than board size {size}'
self.convolve_kernel = np.ones(win_size)
self.observation_space = gym.spaces.Box(
low=-1.0, high=1.0, shape=(size, size), dtype=np.float32
)
self.action_space = gym.spaces.Discrete(size * size)
self.current_board = None
self.current_agent = None
self._last_move = None
self.step_num = None
def reset(self) -> dict:
self.current_board = np.zeros((self.size, self.size), dtype=np.int32)
self.current_agent = 1
self._last_move = (-1, -1)
self.step_num = 0
return {
'agent_id': self.current_agent,
'obs': np.array(self.current_board),
'mask': self.current_board.flatten() == 0
}
def step(self, action: Union[int,
np.ndarray]) -> Tuple[dict, np.ndarray, np.ndarray, dict]:
if self.current_agent is None:
raise ValueError("calling step() of unreset environment is prohibited!")
assert 0 <= action < self.size * self.size
assert self.current_board.item(action) == 0
_current_agent = self.current_agent
self._move(action)
mask = self.current_board.flatten() == 0
is_win, is_opponent_win = False, False
is_win = self._test_win()
# the game is over when one wins or there is only one empty place
done = is_win
if sum(mask) == 1:
done = True
self._move(np.where(mask)[0][0])
is_opponent_win = self._test_win()
if is_win:
reward = 1
elif is_opponent_win:
reward = -1
else:
reward = 0
obs = {
'agent_id': self.current_agent,
'obs': np.array(self.current_board),
'mask': mask
}
rew_agent_1 = reward if _current_agent == 1 else (-reward)
rew_agent_2 = reward if _current_agent == 2 else (-reward)
vec_rew = np.array([rew_agent_1, rew_agent_2], dtype=np.float32)
if done:
self.current_agent = None
return obs, vec_rew, np.array(done), {}
def _move(self, action):
row, col = action // self.size, action % self.size
if self.current_agent == 1:
self.current_board[row, col] = 1
else:
self.current_board[row, col] = -1
self.current_agent = 3 - self.current_agent
self._last_move = (row, col)
self.step_num += 1
def _test_win(self):
"""test if someone wins by checking the situation around last move"""
row, col = self._last_move
rboard = self.current_board[row, :]
cboard = self.current_board[:, col]
current = self.current_board[row, col]
rightup = [
self.current_board[row - i, col + i] for i in range(1, self.size - col)
if row - i >= 0
]
leftdown = [
self.current_board[row + i, col - i] for i in range(1, col + 1)
if row + i < self.size
]
rdiag = np.array(leftdown[::-1] + [current] + rightup)
rightdown = [
self.current_board[row + i, col + i] for i in range(1, self.size - col)
if row + i < self.size
]
leftup = [
self.current_board[row - i, col - i] for i in range(1, col + 1)
if row - i >= 0
]
diag = np.array(leftup[::-1] + [current] + rightdown)
results = [
np.convolve(k, self.convolve_kernel, mode='valid')
for k in (rboard, cboard, rdiag, diag)
]
return any([(np.abs(x) == self.win_size).any() for x in results])
def seed(self, seed: Optional[int] = None) -> int:
pass
def render(self, **kwargs) -> None:
print(f'board (step {self.step_num}):')
pad = '==='
top = pad + '=' * (2 * self.size - 1) + pad
print(top)
def f(i, data):
j, number = data
last_move = i == self._last_move[0] and j == self._last_move[1]
if number == 1:
return 'X' if last_move else 'x'
if number == -1:
return 'O' if last_move else 'o'
return '_'
for i, row in enumerate(self.current_board):
print(pad + ' '.join(map(partial(f, i), enumerate(row))) + pad)
print(top)
def close(self) -> None:
pass
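The `_test_win` method above detects `win_size` identical marks in a row by convolving each board line with a kernel of ones and looking for a window whose absolute sum equals `win_size`. A standalone sketch of that trick with the same encoding (1 for one player, -1 for the other, 0 for empty):

import numpy as np

row = np.array([0, 1, 1, 1, -1])                  # contains three 1's in a row
kernel = np.ones(3)                               # win_size == 3
windows = np.convolve(row, kernel, mode='valid')  # sliding sums over every window of length 3
print(windows)                                    # [2. 3. 1.]
print((np.abs(windows) == 3).any())               # True -> a win for the player encoded as 1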


@@ -0,0 +1,185 @@
import argparse
import os
from typing import List, Optional, Tuple
import gym
import numpy as np
import pettingzoo.butterfly.pistonball_v4 as pistonball_v4
import torch
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
def get_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=1626)
parser.add_argument('--eps-test', type=float, default=0.05)
parser.add_argument('--eps-train', type=float, default=0.1)
parser.add_argument('--buffer-size', type=int, default=2000)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument(
'--gamma', type=float, default=0.9, help='a smaller gamma favors earlier win'
)
parser.add_argument(
'--n-pistons',
type=int,
default=3,
help='number of pistons (agents) in the env'
)
parser.add_argument('--n-step', type=int, default=100)
parser.add_argument('--target-update-freq', type=int, default=320)
parser.add_argument('--epoch', type=int, default=3)
parser.add_argument('--step-per-epoch', type=int, default=500)
parser.add_argument('--step-per-collect', type=int, default=10)
parser.add_argument('--update-per-step', type=float, default=0.1)
parser.add_argument('--batch-size', type=int, default=100)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.0)
parser.add_argument(
'--watch',
default=False,
action='store_true',
help='no training, '
'watch the play of pre-trained models'
)
parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
)
return parser
def get_args() -> argparse.Namespace:
parser = get_parser()
return parser.parse_known_args()[0]
def get_env(args: argparse.Namespace = get_args()):
return PettingZooEnv(pistonball_v4.env(continuous=False, n_pistons=args.n_pistons))
def get_agents(
args: argparse.Namespace = get_args(),
agents: Optional[List[BasePolicy]] = None,
optims: Optional[List[torch.optim.Optimizer]] = None,
) -> Tuple[BasePolicy, List[torch.optim.Optimizer], List]:
env = get_env()
observation_space = env.observation_space['observation'] if isinstance(
env.observation_space, gym.spaces.Dict
) else env.observation_space
args.state_shape = observation_space.shape or observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
if agents is None:
agents = []
optims = []
for _ in range(args.n_pistons):
# model
net = Net(
args.state_shape,
args.action_shape,
hidden_sizes=args.hidden_sizes,
device=args.device
).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
agent = DQNPolicy(
net,
optim,
args.gamma,
args.n_step,
target_update_freq=args.target_update_freq
)
agents.append(agent)
optims.append(optim)
policy = MultiAgentPolicyManager(agents, env)
return policy, optims, env.agents
def train_agent(
args: argparse.Namespace = get_args(),
agents: Optional[List[BasePolicy]] = None,
optims: Optional[List[torch.optim.Optimizer]] = None,
) -> Tuple[dict, BasePolicy]:
train_envs = DummyVectorEnv([get_env for _ in range(args.training_num)])
test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
policy, optim, agents = get_agents(args, agents=agents, optims=optims)
# collector
train_collector = Collector(
policy,
train_envs,
VectorReplayBuffer(args.buffer_size, len(train_envs)),
exploration_noise=True
)
test_collector = Collector(policy, test_envs, exploration_noise=True)
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
log_path = os.path.join(args.logdir, 'pistonball', 'dqn')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = TensorboardLogger(writer)
def save_fn(policy):
pass
def stop_fn(mean_rewards):
return False
def train_fn(epoch, env_step):
[agent.set_eps(args.eps_train) for agent in policy.policies.values()]
def test_fn(epoch, env_step):
[agent.set_eps(args.eps_test) for agent in policy.policies.values()]
def reward_metric(rews):
return rews[:, 0]
# trainer
result = offpolicy_trainer(
policy,
train_collector,
test_collector,
args.epoch,
args.step_per_epoch,
args.step_per_collect,
args.test_num,
args.batch_size,
train_fn=train_fn,
test_fn=test_fn,
stop_fn=stop_fn,
save_fn=save_fn,
update_per_step=args.update_per_step,
logger=logger,
test_in_train=False,
reward_metric=reward_metric
)
return result, policy
def watch(
args: argparse.Namespace = get_args(), policy: Optional[BasePolicy] = None
) -> None:
env = get_env()
policy.eval()
[agent.set_eps(args.eps_test) for agent in policy.policies.values()]
collector = Collector(policy, env, exploration_noise=True)
result = collector.collect(n_episode=1, render=args.render)
rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews[:, 0].mean()}, length: {lens.mean()}")


@@ -0,0 +1,276 @@
import argparse
import os
from typing import Any, Dict, List, Optional, Tuple, Union
import gym
import numpy as np
import pettingzoo.butterfly.pistonball_v4 as pistonball_v4
import torch
import torch.nn as nn
from torch.distributions import Independent, Normal
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.policy import BasePolicy, MultiAgentPolicyManager, PPOPolicy
from tianshou.trainer import onpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.continuous import ActorProb, Critic
class DQN(nn.Module):
"""Reference: Human-level control through deep reinforcement learning.
For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
"""
def __init__(
self,
c: int,
h: int,
w: int,
device: Union[str, int, torch.device] = "cpu",
) -> None:
super().__init__()
self.device = device
self.c = c
self.h = h
self.w = w
self.net = nn.Sequential(
nn.Conv2d(c, 32, kernel_size=8, stride=4), nn.ReLU(inplace=True),
nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(inplace=True),
nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(inplace=True),
nn.Flatten()
)
with torch.no_grad():
self.output_dim = np.prod(self.net(torch.zeros(1, c, h, w)).shape[1:])
def forward(
self,
x: Union[np.ndarray, torch.Tensor],
state: Optional[Any] = None,
info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]:
r"""Mapping: x -> Q(x, \*)."""
x = torch.as_tensor(x, device=self.device, dtype=torch.float32)
return self.net(x.reshape(-1, self.c, self.w, self.h)), state
def get_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=1626)
parser.add_argument('--eps-test', type=float, default=0.05)
parser.add_argument('--eps-train', type=float, default=0.1)
parser.add_argument('--buffer-size', type=int, default=2000)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument(
'--gamma', type=float, default=0.9, help='a smaller gamma favors earlier win'
)
parser.add_argument(
'--n-pistons',
type=int,
default=3,
help='number of pistons (agents) in the env'
)
parser.add_argument('--n-step', type=int, default=100)
parser.add_argument('--target-update-freq', type=int, default=320)
parser.add_argument('--epoch', type=int, default=5)
parser.add_argument('--step-per-epoch', type=int, default=500)
parser.add_argument('--step-per-collect', type=int, default=10)
parser.add_argument('--episode-per-collect', type=int, default=16)
parser.add_argument('--repeat-per-collect', type=int, default=2)
parser.add_argument('--update-per-step', type=float, default=0.1)
parser.add_argument('--batch-size', type=int, default=1000)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
parser.add_argument('--training-num', type=int, default=1000)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument(
'--watch',
default=False,
action='store_true',
help='no training, '
'watch the play of pre-trained models'
)
parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
)
# ppo special
parser.add_argument('--vf-coef', type=float, default=0.25)
parser.add_argument('--ent-coef', type=float, default=0.0)
parser.add_argument('--eps-clip', type=float, default=0.2)
parser.add_argument('--max-grad-norm', type=float, default=0.5)
parser.add_argument('--gae-lambda', type=float, default=0.95)
parser.add_argument('--rew-norm', type=int, default=1)
parser.add_argument('--dual-clip', type=float, default=None)
parser.add_argument('--value-clip', type=int, default=1)
parser.add_argument('--norm-adv', type=int, default=1)
parser.add_argument('--recompute-adv', type=int, default=0)
parser.add_argument('--resume', action="store_true")
parser.add_argument("--save-interval", type=int, default=4)
parser.add_argument('--render', type=float, default=0.0)
return parser
def get_args() -> argparse.Namespace:
parser = get_parser()
return parser.parse_known_args()[0]
def get_env(args: argparse.Namespace = get_args()):
return PettingZooEnv(pistonball_v4.env(continuous=True, n_pistons=args.n_pistons))
def get_agents(
args: argparse.Namespace = get_args(),
agents: Optional[List[BasePolicy]] = None,
optims: Optional[List[torch.optim.Optimizer]] = None,
) -> Tuple[BasePolicy, List[torch.optim.Optimizer], List]:
env = get_env()
observation_space = env.observation_space['observation'] if isinstance(
env.observation_space, gym.spaces.Dict
) else env.observation_space
args.state_shape = observation_space.shape or observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
args.max_action = env.action_space.high[0]
if agents is None:
agents = []
optims = []
for _ in range(args.n_pistons):
# model
net = DQN(
observation_space.shape[2],
observation_space.shape[1],
observation_space.shape[0],
device=args.device
).to(args.device)
actor = ActorProb(
net, args.action_shape, max_action=args.max_action, device=args.device
).to(args.device)
net2 = DQN(
observation_space.shape[2],
observation_space.shape[1],
observation_space.shape[0],
device=args.device
).to(args.device)
critic = Critic(net2, device=args.device).to(args.device)
for m in set(actor.modules()).union(critic.modules()):
if isinstance(m, torch.nn.Linear):
torch.nn.init.orthogonal_(m.weight)
torch.nn.init.zeros_(m.bias)
optim = torch.optim.Adam(
set(actor.parameters()).union(critic.parameters()), lr=args.lr
)
def dist(*logits):
return Independent(Normal(*logits), 1)
agent = PPOPolicy(
actor,
critic,
optim,
dist,
discount_factor=args.gamma,
max_grad_norm=args.max_grad_norm,
eps_clip=args.eps_clip,
vf_coef=args.vf_coef,
ent_coef=args.ent_coef,
reward_normalization=args.rew_norm,
advantage_normalization=args.norm_adv,
recompute_advantage=args.recompute_adv,
# dual_clip=args.dual_clip,
# dual clip causes monotonically increasing log_std :)
value_clip=args.value_clip,
gae_lambda=args.gae_lambda,
action_space=env.action_space
)
agents.append(agent)
optims.append(optim)
policy = MultiAgentPolicyManager(
agents, env, action_scaling=True, action_bound_method='clip'
)
return policy, optims, env.agents
def train_agent(
args: argparse.Namespace = get_args(),
agents: Optional[List[BasePolicy]] = None,
optims: Optional[List[torch.optim.Optimizer]] = None,
) -> Tuple[dict, BasePolicy]:
train_envs = DummyVectorEnv([get_env for _ in range(args.training_num)])
test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
policy, optim, agents = get_agents(args, agents=agents, optims=optims)
# collector
train_collector = Collector(
policy,
train_envs,
VectorReplayBuffer(args.buffer_size, len(train_envs)),
exploration_noise=False # True
)
test_collector = Collector(policy, test_envs)
# train_collector.collect(n_step=args.batch_size * args.training_num)
# log
log_path = os.path.join(args.logdir, 'pistonball', 'dqn')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = TensorboardLogger(writer)
def save_fn(policy):
pass
def stop_fn(mean_rewards):
return False
def train_fn(epoch, env_step):
[agent.set_eps(args.eps_train) for agent in policy.policies.values()]
def test_fn(epoch, env_step):
[agent.set_eps(args.eps_test) for agent in policy.policies.values()]
def reward_metric(rews):
return rews[:, 0]
# trainer
result = onpolicy_trainer(
policy,
train_collector,
test_collector,
args.epoch,
args.step_per_epoch,
args.repeat_per_collect,
args.test_num,
args.batch_size,
episode_per_collect=args.episode_per_collect,
stop_fn=stop_fn,
save_fn=save_fn,
logger=logger,
resume_from_log=args.resume
)
return result, policy
def watch(
args: argparse.Namespace = get_args(), policy: Optional[BasePolicy] = None
) -> None:
env = get_env()
policy.eval()
collector = Collector(policy, env)
result = collector.collect(n_episode=1, render=args.render)
rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews[:, 0].mean()}, length: {lens.mean()}")


@@ -0,0 +1,21 @@
import pprint
from pistonball import get_args, train_agent, watch
def test_piston_ball(args=get_args()):
if args.watch:
watch(args)
return
result, agent = train_agent(args)
# assert result["best_reward"] >= args.win_rate
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
watch(args, agent)
if __name__ == '__main__':
test_piston_ball(get_args())


@@ -0,0 +1,21 @@
import pprint
from pistonball_continuous import get_args, train_agent, watch
def test_piston_ball_continuous(args=get_args()):
if args.watch:
watch(args)
return
result, agent = train_agent(args)
assert result["best_reward"] >= 30.0
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
watch(args, agent)
if __name__ == '__main__':
test_piston_ball_continuous(get_args())


@@ -1,21 +1,21 @@
import pprint
from tic_tac_toe import get_args, train_agent, watch
def test_tic_tac_toe(args=get_args()):
if args.watch:
watch(args)
return
result, agent = train_agent(args)
assert result["best_reward"] >= args.win_rate
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
watch(args, agent)
if __name__ == '__main__':
test_tic_tac_toe(get_args())
import pprint
from tic_tac_toe import get_args, train_agent, watch
def test_tic_tac_toe(args=get_args()):
if args.watch:
watch(args)
return
result, agent = train_agent(args)
assert result["best_reward"] >= args.win_rate
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
watch(args, agent)
if __name__ == '__main__':
test_tic_tac_toe(get_args())


@@ -1,232 +1,241 @@
import argparse
import os
from copy import deepcopy
from typing import Optional, Tuple
import numpy as np
import torch
from tic_tac_toe_env import TicTacToeEnv
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import (
BasePolicy,
DQNPolicy,
MultiAgentPolicyManager,
RandomPolicy,
)
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
def get_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=1626)
parser.add_argument('--eps-test', type=float, default=0.05)
parser.add_argument('--eps-train', type=float, default=0.1)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument(
'--gamma', type=float, default=0.9, help='a smaller gamma favors earlier win'
)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=320)
parser.add_argument('--epoch', type=int, default=20)
parser.add_argument('--step-per-epoch', type=int, default=5000)
parser.add_argument('--step-per-collect', type=int, default=10)
parser.add_argument('--update-per-step', type=float, default=0.1)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument(
'--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]
)
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.1)
parser.add_argument('--board-size', type=int, default=6)
parser.add_argument('--win-size', type=int, default=4)
parser.add_argument(
'--win-rate', type=float, default=0.9, help='the expected winning rate'
)
parser.add_argument(
'--watch',
default=False,
action='store_true',
help='no training, '
'watch the play of pre-trained models'
)
parser.add_argument(
'--agent-id',
type=int,
default=2,
help='the learned agent plays as the'
' agent_id-th player. Choices are 1 and 2.'
)
parser.add_argument(
'--resume-path',
type=str,
default='',
help='the path of agent pth file '
'for resuming from a pre-trained agent'
)
parser.add_argument(
'--opponent-path',
type=str,
default='',
help='the path of opponent agent pth file '
'for resuming from a pre-trained agent'
)
parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
)
return parser
def get_args() -> argparse.Namespace:
parser = get_parser()
return parser.parse_known_args()[0]
def get_agents(
args: argparse.Namespace = get_args(),
agent_learn: Optional[BasePolicy] = None,
agent_opponent: Optional[BasePolicy] = None,
optim: Optional[torch.optim.Optimizer] = None,
) -> Tuple[BasePolicy, torch.optim.Optimizer]:
env = TicTacToeEnv(args.board_size, args.win_size)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
if agent_learn is None:
# model
net = Net(
args.state_shape,
args.action_shape,
hidden_sizes=args.hidden_sizes,
device=args.device
).to(args.device)
if optim is None:
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
agent_learn = DQNPolicy(
net,
optim,
args.gamma,
args.n_step,
target_update_freq=args.target_update_freq
)
if args.resume_path:
agent_learn.load_state_dict(torch.load(args.resume_path))
if agent_opponent is None:
if args.opponent_path:
agent_opponent = deepcopy(agent_learn)
agent_opponent.load_state_dict(torch.load(args.opponent_path))
else:
agent_opponent = RandomPolicy()
if args.agent_id == 1:
agents = [agent_learn, agent_opponent]
else:
agents = [agent_opponent, agent_learn]
policy = MultiAgentPolicyManager(agents)
return policy, optim
def train_agent(
args: argparse.Namespace = get_args(),
agent_learn: Optional[BasePolicy] = None,
agent_opponent: Optional[BasePolicy] = None,
optim: Optional[torch.optim.Optimizer] = None,
) -> Tuple[dict, BasePolicy]:
def env_func():
return TicTacToeEnv(args.board_size, args.win_size)
train_envs = DummyVectorEnv([env_func for _ in range(args.training_num)])
test_envs = DummyVectorEnv([env_func for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
policy, optim = get_agents(
args, agent_learn=agent_learn, agent_opponent=agent_opponent, optim=optim
)
# collector
train_collector = Collector(
policy,
train_envs,
VectorReplayBuffer(args.buffer_size, len(train_envs)),
exploration_noise=True
)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = TensorboardLogger(writer)
def save_fn(policy):
if hasattr(args, 'model_save_path'):
model_save_path = args.model_save_path
else:
model_save_path = os.path.join(
args.logdir, 'tic_tac_toe', 'dqn', 'policy.pth'
)
torch.save(policy.policies[args.agent_id - 1].state_dict(), model_save_path)
def stop_fn(mean_rewards):
return mean_rewards >= args.win_rate
def train_fn(epoch, env_step):
policy.policies[args.agent_id - 1].set_eps(args.eps_train)
def test_fn(epoch, env_step):
policy.policies[args.agent_id - 1].set_eps(args.eps_test)
def reward_metric(rews):
return rews[:, args.agent_id - 1]
# trainer
result = offpolicy_trainer(
policy,
train_collector,
test_collector,
args.epoch,
args.step_per_epoch,
args.step_per_collect,
args.test_num,
args.batch_size,
train_fn=train_fn,
test_fn=test_fn,
stop_fn=stop_fn,
save_fn=save_fn,
update_per_step=args.update_per_step,
logger=logger,
test_in_train=False,
reward_metric=reward_metric
)
return result, policy.policies[args.agent_id - 1]
def watch(
args: argparse.Namespace = get_args(),
agent_learn: Optional[BasePolicy] = None,
agent_opponent: Optional[BasePolicy] = None,
) -> None:
env = TicTacToeEnv(args.board_size, args.win_size)
policy, optim = get_agents(
args, agent_learn=agent_learn, agent_opponent=agent_opponent
)
policy.eval()
policy.policies[args.agent_id - 1].set_eps(args.eps_test)
collector = Collector(policy, env, exploration_noise=True)
result = collector.collect(n_episode=1, render=args.render)
rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews[:, args.agent_id - 1].mean()}, length: {lens.mean()}")
import argparse
import os
from copy import deepcopy
from typing import Optional, Tuple
import gym
import numpy as np
import torch
from pettingzoo.classic import tictactoe_v3
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.policy import (
BasePolicy,
DQNPolicy,
MultiAgentPolicyManager,
RandomPolicy,
)
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
def get_env():
return PettingZooEnv(tictactoe_v3.env())
def get_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=1626)
parser.add_argument('--eps-test', type=float, default=0.05)
parser.add_argument('--eps-train', type=float, default=0.1)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument(
'--gamma', type=float, default=0.9, help='a smaller gamma favors earlier win'
)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=320)
parser.add_argument('--epoch', type=int, default=50)
parser.add_argument('--step-per-epoch', type=int, default=1000)
parser.add_argument('--step-per-collect', type=int, default=10)
parser.add_argument('--update-per-step', type=float, default=0.1)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument(
'--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]
)
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.1)
parser.add_argument(
'--win-rate',
type=float,
default=0.6,
help='the expected winning rate: Optimal policy can get 0.7'
)
parser.add_argument(
'--watch',
default=False,
action='store_true',
help='no training, '
'watch the play of pre-trained models'
)
parser.add_argument(
'--agent-id',
type=int,
default=2,
help='the learned agent plays as the'
' agent_id-th player. Choices are 1 and 2.'
)
parser.add_argument(
'--resume-path',
type=str,
default='',
help='the path of agent pth file '
'for resuming from a pre-trained agent'
)
parser.add_argument(
'--opponent-path',
type=str,
default='',
help='the path of opponent agent pth file '
'for resuming from a pre-trained agent'
)
parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
)
return parser
def get_args() -> argparse.Namespace:
parser = get_parser()
return parser.parse_known_args()[0]
def get_agents(
args: argparse.Namespace = get_args(),
agent_learn: Optional[BasePolicy] = None,
agent_opponent: Optional[BasePolicy] = None,
optim: Optional[torch.optim.Optimizer] = None,
) -> Tuple[BasePolicy, torch.optim.Optimizer, list]:
env = get_env()
observation_space = env.observation_space['observation'] if isinstance(
env.observation_space, gym.spaces.Dict
) else env.observation_space
args.state_shape = observation_space.shape or observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
if agent_learn is None:
# model
net = Net(
args.state_shape,
args.action_shape,
hidden_sizes=args.hidden_sizes,
device=args.device
).to(args.device)
if optim is None:
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
agent_learn = DQNPolicy(
net,
optim,
args.gamma,
args.n_step,
target_update_freq=args.target_update_freq
)
if args.resume_path:
agent_learn.load_state_dict(torch.load(args.resume_path))
if agent_opponent is None:
if args.opponent_path:
agent_opponent = deepcopy(agent_learn)
agent_opponent.load_state_dict(torch.load(args.opponent_path))
else:
agent_opponent = RandomPolicy()
if args.agent_id == 1:
agents = [agent_learn, agent_opponent]
else:
agents = [agent_opponent, agent_learn]
policy = MultiAgentPolicyManager(agents, env)
return policy, optim, env.agents
def train_agent(
args: argparse.Namespace = get_args(),
agent_learn: Optional[BasePolicy] = None,
agent_opponent: Optional[BasePolicy] = None,
optim: Optional[torch.optim.Optimizer] = None,
) -> Tuple[dict, BasePolicy]:
train_envs = DummyVectorEnv([get_env for _ in range(args.training_num)])
test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
policy, optim, agents = get_agents(
args, agent_learn=agent_learn, agent_opponent=agent_opponent, optim=optim
)
# collector
train_collector = Collector(
policy,
train_envs,
VectorReplayBuffer(args.buffer_size, len(train_envs)),
exploration_noise=True
)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = TensorboardLogger(writer)
def save_fn(policy):
if hasattr(args, 'model_save_path'):
model_save_path = args.model_save_path
else:
model_save_path = os.path.join(
args.logdir, 'tic_tac_toe', 'dqn', 'policy.pth'
)
torch.save(
policy.policies[agents[args.agent_id - 1]].state_dict(), model_save_path
)
def stop_fn(mean_rewards):
return mean_rewards >= args.win_rate
def train_fn(epoch, env_step):
policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_train)
def test_fn(epoch, env_step):
policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test)
def reward_metric(rews):
return rews[:, args.agent_id - 1]
# trainer
result = offpolicy_trainer(
policy,
train_collector,
test_collector,
args.epoch,
args.step_per_epoch,
args.step_per_collect,
args.test_num,
args.batch_size,
train_fn=train_fn,
test_fn=test_fn,
stop_fn=stop_fn,
save_fn=save_fn,
update_per_step=args.update_per_step,
logger=logger,
test_in_train=False,
reward_metric=reward_metric
)
return result, policy.policies[agents[args.agent_id - 1]]
def watch(
args: argparse.Namespace = get_args(),
agent_learn: Optional[BasePolicy] = None,
agent_opponent: Optional[BasePolicy] = None,
) -> None:
env = get_env()
policy, optim, agents = get_agents(
args, agent_learn=agent_learn, agent_opponent=agent_opponent
)
policy.eval()
policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test)
collector = Collector(policy, env, exploration_noise=True)
result = collector.collect(n_episode=1, render=args.render)
rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews[:, args.agent_id - 1].mean()}, length: {lens.mean()}")


@@ -1,6 +1,6 @@
"""Env package."""
from tianshou.env.maenv import MultiAgentEnv
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.env.venvs import (
BaseVectorEnv,
DummyVectorEnv,
@@ -15,5 +15,5 @@ __all__ = [
"SubprocVectorEnv",
"ShmemVectorEnv",
"RayVectorEnv",
"MultiAgentEnv",
"PettingZooEnv",
]

tianshou/env/maenv.py

@@ -1,65 +0,0 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Tuple
import gym
import numpy as np
class MultiAgentEnv(ABC, gym.Env):
"""The interface for multi-agent environments.
Multi-agent environments must be wrapped as
:class:`~tianshou.env.MultiAgentEnv`. Here is the usage:
::
env = MultiAgentEnv(...)
# obs is a dict containing obs, agent_id, and mask
obs = env.reset()
act = policy(obs)
obs, rew, done, info = env.step(act)
env.close()
In the action mask, available actions are set to 1 and unavailable actions to 0.
Further usage can be found at :ref:`marl_example`.
"""
def __init__(self) -> None:
pass
@abstractmethod
def reset(self) -> dict:
"""Reset the state.
Return the initial state, first agent_id, and the initial action set,
for example, ``{'obs': obs, 'agent_id': agent_id, 'mask': mask}``.
"""
pass
@abstractmethod
def step(
self, action: np.ndarray
) -> Tuple[Dict[str, Any], np.ndarray, np.ndarray, np.ndarray]:
"""Run one timestep of the environments dynamics.
When the end of episode is reached, you are responsible for calling
reset() to reset the environments state.
Accept action and return a tuple (obs, rew, done, info).
:param numpy.ndarray action: action provided by a agent.
:return: A tuple including four items:
* ``obs`` a dict containing obs, agent_id, and mask, which means \
that it is the ``agent_id`` player's turn to play with ``obs``\
observation and ``mask``.
* ``rew`` a numpy.ndarray, the amount of rewards returned after \
previous actions. Depending on the specific environment, this \
can be either a scalar reward for current agent or a vector \
reward for all the agents.
* ``done`` a numpy.ndarray, whether the episode has ended, in \
which case further step() calls will return undefined results
* ``info`` a numpy.ndarray, contains auxiliary diagnostic \
information (helpful for debugging, and sometimes learning)
"""
pass

tianshou/env/pettingzoo_env.py

@@ -0,0 +1,112 @@
from abc import ABC
from typing import Any, Dict, List, Tuple
import gym.spaces
from pettingzoo.utils.env import AECEnv
from pettingzoo.utils.wrappers import BaseWrapper
class PettingZooEnv(AECEnv, gym.Env, ABC):
"""The interface for petting zoo environments.
Multi-agent environments must be wrapped as
:class:`~tianshou.env.PettingZooEnv`. Here is the usage:
::
env = PettingZooEnv(...)
# obs is a dict containing obs, agent_id, and mask
obs = env.reset()
action = policy(obs)
obs, rew, done, info = env.step(action)
env.close()
In the action mask, available actions are marked True and unavailable actions False.
Further usage can be found at :ref:`marl_example`.
"""
def __init__(self, env: BaseWrapper):
super().__init__()
self.env = env
# agent idx list
self.agents = self.env.possible_agents
self.agent_idx = {}
for i, agent_id in enumerate(self.agents):
self.agent_idx[agent_id] = i
# Get dictionaries of obs_spaces and act_spaces
self.observation_spaces = self.env.observation_spaces
self.action_spaces = self.env.action_spaces
self.rewards = [0] * len(self.agents)
# Get first observation space, assuming all agents have equal space
self.observation_space: Any = self.observation_space(self.agents[0])
# Get first action space, assuming all agents have equal space
self.action_space: Any = self.action_space(self.agents[0])
assert all(self.env.observation_space(agent) == self.observation_space
for agent in self.agents), \
"Observation spaces for all agents must be identical. Perhaps " \
"SuperSuit's pad_observations wrapper can help (useage: " \
"`supersuit.aec_wrappers.pad_observations(env)`"
assert all(self.env.action_space(agent) == self.action_space
for agent in self.agents), \
"Action spaces for all agents must be identical. Perhaps " \
"SuperSuit's pad_action_space wrapper can help (useage: " \
"`supersuit.aec_wrappers.pad_action_space(env)`"
self.reset()
def reset(self) -> dict:
self.env.reset()
observation = self.env.observe(self.env.agent_selection)
if isinstance(observation, dict) and 'action_mask' in observation:
return {
'agent_id': self.env.agent_selection,
'obs': observation['observation'],
'mask':
[True if obm == 1 else False for obm in observation['action_mask']]
}
else:
if isinstance(self.action_space, gym.spaces.Discrete):
return {
'agent_id': self.env.agent_selection,
'obs': observation,
'mask': [True] * self.env.action_space(self.env.agent_selection).n
}
else:
return {'agent_id': self.env.agent_selection, 'obs': observation}
def step(self, action: Any) -> Tuple[Dict, List[int], bool, Dict]:
self.env.step(action)
observation, rew, done, info = self.env.last()
if isinstance(observation, dict) and 'action_mask' in observation:
obs = {
'agent_id': self.env.agent_selection,
'obs': observation['observation'],
'mask':
[True if obm == 1 else False for obm in observation['action_mask']]
}
else:
if isinstance(self.action_space, gym.spaces.Discrete):
obs = {
'agent_id': self.env.agent_selection,
'obs': observation,
'mask': [True] * self.env.action_space(self.env.agent_selection).n
}
else:
obs = {'agent_id': self.env.agent_selection, 'obs': observation}
for agent_id, reward in self.env.rewards.items():
self.rewards[self.agent_idx[agent_id]] = reward
return obs, self.rewards, done, info
def close(self) -> None:
self.env.close()
def seed(self, seed: Any = None) -> None:
self.env.seed(seed)
def render(self, mode: str = "human") -> Any:
return self.env.render(mode)
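A minimal usage sketch of the contract this wrapper exposes, using the classic tic-tac-toe environment that appears elsewhere in this commit (the chosen action is illustrative):

from pettingzoo.classic import tictactoe_v3

from tianshou.env.pettingzoo_env import PettingZooEnv

env = PettingZooEnv(tictactoe_v3.env())
obs = env.reset()                          # {'agent_id': ..., 'obs': ..., 'mask': [...]}
legal = [i for i, ok in enumerate(obs['mask']) if ok]
obs, rew, done, info = env.step(legal[0])  # rew is a list with one entry per agent
print(obs['agent_id'], rew[env.agent_idx[obs['agent_id']]])  # per-agent rewards are indexed via env.agent_idx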


@@ -2,6 +2,7 @@ from typing import Any, Callable, List, Optional, Tuple, Union
import gym
import numpy as np
import pettingzoo
from tianshou.env.worker import (
DummyEnvWorker,
@@ -364,7 +365,10 @@ class DummyVectorEnv(BaseVectorEnv):
Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
"""
def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
def __init__(
self, env_fns: List[Callable[[], Union[gym.Env, pettingzoo.AECEnv]]],
**kwargs: Any
) -> None:
super().__init__(env_fns, DummyEnvWorker, **kwargs)


@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
from tianshou.data import Batch, ReplayBuffer
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.policy import BasePolicy
@@ -16,21 +17,29 @@ class MultiAgentPolicyManager(BasePolicy):
:ref:`marl_example` can help you better understand this procedure.
"""
def __init__(self, policies: List[BasePolicy], **kwargs: Any) -> None:
super().__init__(**kwargs)
self.policies = policies
def __init__(
self, policies: List[BasePolicy], env: PettingZooEnv, **kwargs: Any
) -> None:
super().__init__(action_space=env.action_space, **kwargs)
assert (
len(policies) == len(env.agents)
), "One policy must be assigned for each agent."
self.agent_idx = env.agent_idx
for i, policy in enumerate(policies):
# agent_id 0 is reserved for the environment proxy
# (this MultiAgentPolicyManager)
policy.set_agent_id(i + 1)
policy.set_agent_id(env.agents[i])
self.policies = dict(zip(env.agents, policies))
def replace_policy(self, policy: BasePolicy, agent_id: int) -> None:
"""Replace the "agent_id"th policy in this manager."""
self.policies[agent_id - 1] = policy
policy.set_agent_id(agent_id)
self.policies[agent_id] = policy
def process_fn(
self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
) -> Batch:
"""Dispatch batch data from obs.agent_id to every policy's process_fn.
@@ -45,18 +54,21 @@ class MultiAgentPolicyManager(BasePolicy):
# Since we do not override buffer.__setattr__, here we use _meta to
# change buffer.rew, otherwise buffer.rew = Batch() has no effect.
save_rew, buffer._meta.rew = buffer.rew, Batch()
for policy in self.policies:
agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0]
for agent, policy in self.policies.items():
agent_index = np.nonzero(batch.obs.agent_id == agent)[0]
if len(agent_index) == 0:
results[f"agent_{policy.agent_id}"] = Batch()
results[agent] = Batch()
continue
tmp_batch, tmp_indices = batch[agent_index], indices[agent_index]
tmp_batch, tmp_indice = batch[agent_index], indice[agent_index]
if has_rew:
tmp_batch.rew = tmp_batch.rew[:, policy.agent_id - 1]
buffer._meta.rew = save_rew[:, policy.agent_id - 1]
results[f"agent_{policy.agent_id}"] = policy.process_fn(
tmp_batch, buffer, tmp_indices
)
tmp_batch.rew = tmp_batch.rew[:, self.agent_idx[agent]]
buffer._meta.rew = save_rew[:, self.agent_idx[agent]]
if not hasattr(tmp_batch.obs, "mask"):
if hasattr(tmp_batch.obs, 'obs'):
tmp_batch.obs = tmp_batch.obs.obs
if hasattr(tmp_batch.obs_next, 'obs'):
tmp_batch.obs_next = tmp_batch.obs_next.obs
results[agent] = policy.process_fn(tmp_batch, buffer, tmp_indice)
if has_rew: # restore from save_rew
buffer._meta.rew = save_rew
return Batch(results)
@@ -64,8 +76,8 @@ class MultiAgentPolicyManager(BasePolicy):
def exploration_noise(self, act: Union[np.ndarray, Batch],
batch: Batch) -> Union[np.ndarray, Batch]:
"""Add exploration noise from sub-policy onto act."""
for policy in self.policies:
agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0]
for agent_id, policy in self.policies.items():
agent_index = np.nonzero(batch.obs.agent_id == agent_id)[0]
if len(agent_index) == 0:
continue
act[agent_index] = policy.exploration_noise(
@@ -104,7 +116,7 @@ class MultiAgentPolicyManager(BasePolicy):
"""
results: List[Tuple[bool, np.ndarray, Batch, Union[np.ndarray, Batch],
Batch]] = []
for policy in self.policies:
for agent_id, policy in self.policies.items():
# This part of code is difficult to understand.
# Let's follow an example with two agents
# batch.obs.agent_id is [1, 2, 1, 2, 1, 2] (with batch_size == 6)
@@ -112,7 +124,7 @@
# agent_index for agent 1 is [0, 2, 4]
# agent_index for agent 2 is [1, 3, 5]
# we separate the transition of each agent according to agent_id
agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0]
agent_index = np.nonzero(batch.obs.agent_id == agent_id)[0]
if len(agent_index) == 0:
# (has_data, agent_index, out, act, state)
results.append((False, np.array([-1]), Batch(), Batch(), Batch()))
@@ -120,11 +132,15 @@
tmp_batch = batch[agent_index]
if isinstance(tmp_batch.rew, np.ndarray):
# reward can be empty Batch (after initial reset) or nparray.
tmp_batch.rew = tmp_batch.rew[:, policy.agent_id - 1]
tmp_batch.rew = tmp_batch.rew[:, self.agent_idx[agent_id]]
if not hasattr(tmp_batch.obs, "mask"):
if hasattr(tmp_batch.obs, 'obs'):
tmp_batch.obs = tmp_batch.obs.obs
if hasattr(tmp_batch.obs_next, 'obs'):
tmp_batch.obs_next = tmp_batch.obs_next.obs
out = policy(
batch=tmp_batch,
state=None if state is None else state["agent_" +
str(policy.agent_id)],
state=None if state is None else state[agent_id],
**kwargs
)
act = out.act
@@ -141,12 +157,12 @@
]
)
state_dict, out_dict = {}, {}
for policy, (has_data, agent_index, out, act,
state) in zip(self.policies, results):
for (agent_id, _), (has_data, agent_index, out, act,
state) in zip(self.policies.items(), results):
if has_data:
holder.act[agent_index] = act
state_dict["agent_" + str(policy.agent_id)] = state
out_dict["agent_" + str(policy.agent_id)] = out
state_dict[agent_id] = state
out_dict[agent_id] = out
holder["out"] = out_dict
holder["state"] = state_dict
return holder
@@ -168,10 +184,10 @@
}
"""
results = {}
for policy in self.policies:
data = batch[f"agent_{policy.agent_id}"]
for agent_id, policy in self.policies.items():
data = batch[agent_id]
if not data.is_empty():
out = policy.learn(batch=data, **kwargs)
for k, v in out.items():
results["agent_" + str(policy.agent_id) + "/" + k] = v
results[agent_id + "/" + k] = v
return results
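With this change, `policy.policies` is keyed by the PettingZoo agent name instead of a 1-based integer id. A hedged sketch of how the manager is now assembled and addressed (RandomPolicy stands in for trained sub-policies):

from pettingzoo.classic import tictactoe_v3

from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.policy import MultiAgentPolicyManager, RandomPolicy

env = PettingZooEnv(tictactoe_v3.env())
policy = MultiAgentPolicyManager([RandomPolicy() for _ in env.agents], env)
for agent_id, sub_policy in policy.policies.items():
    print(agent_id, type(sub_policy).__name__)  # one entry per name in env.agents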