Pettingzoo support (#494)
Co-authored-by: Rodrigo de Lazcano <r.l.p.v96@gmail.com>
Co-authored-by: J K Terry <justinkterry@gmail.com>
parent d85bc19269
commit c7e2e56fac
setup.py (4 changes)

@@ -55,6 +55,7 @@ setup(
         "torch>=1.4.0",
         "numba>=0.51.0",
         "h5py>=2.10.0",  # to match tensorflow's minimal requirements
+        "pettingzoo>=1.12,<=1.13",
     ],
     extras_require={
         "dev": [
@@ -74,6 +75,9 @@ setup(
             "pydocstyle",
             "doc8",
             "scipy",
+            "pillow",
+            "pygame>=2.1.0",  # pettingzoo test cases pistonball
+            "pymunk>=6.2.1",  # pettingzoo test cases pistonball
         ],
         "atari": ["atari_py", "opencv-python"],
         "mujoco": ["mujoco_py"],
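The new runtime dependency is pinned to a narrow PettingZoo range. As a quick local sanity check before running the new tests, one can confirm the installed version; this sketch is illustrative only (it assumes Python 3.8+ for importlib.metadata) and is not part of the commit.

# Illustrative only, not part of the diff: check the installed PettingZoo
# version against the pin added above (>=1.12,<=1.13).
from importlib.metadata import version

print("pettingzoo", version("pettingzoo"), "- this commit expects >=1.12,<=1.13")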
Deleted file (87 lines) — the old Gomoku self-play example:

@@ -1,87 +0,0 @@
import os
import pprint
from copy import deepcopy

import numpy as np
from tic_tac_toe import get_agents, get_parser, train_agent, watch
from tic_tac_toe_env import TicTacToeEnv
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector
from tianshou.env import DummyVectorEnv
from tianshou.policy import RandomPolicy
from tianshou.utils import TensorboardLogger


def get_args():
    parser = get_parser()
    parser.add_argument('--self_play_round', type=int, default=20)
    args = parser.parse_known_args()[0]
    return args


def gomoku(args=get_args()):
    Collector._default_rew_metric = lambda x: x[args.agent_id - 1]
    if args.watch:
        watch(args)
        return

    policy, optim = get_agents(args)
    agent_learn = policy.policies[args.agent_id - 1]
    agent_opponent = policy.policies[2 - args.agent_id]

    # log
    log_path = os.path.join(args.logdir, 'Gomoku', 'dqn')
    writer = SummaryWriter(log_path)
    args.logger = TensorboardLogger(writer)

    opponent_pool = [agent_opponent]

    def env_func():
        return TicTacToeEnv(args.board_size, args.win_size)

    test_envs = DummyVectorEnv([env_func for _ in range(args.test_num)])
    for round in range(args.self_play_round):
        rews = []
        agent_learn.set_eps(0.0)
        # compute the reward over previous learner
        for opponent in opponent_pool:
            policy.replace_policy(opponent, 3 - args.agent_id)
            test_collector = Collector(policy, test_envs)
            results = test_collector.collect(n_episode=100)
            rews.append(results['rews'].mean())
        rews = np.array(rews)
        # weight opponent by their difficulty level
        rews = np.exp(-rews * 10.0)
        rews /= np.sum(rews)
        total_epoch = args.epoch
        args.epoch = 1
        for epoch in range(total_epoch):
            # sample one opponent
            opp_id = np.random.choice(len(opponent_pool), size=1, p=rews)
            print(f'selection probability {rews.tolist()}')
            print(f'selected opponent {opp_id}')
            opponent = opponent_pool[opp_id.item(0)]
            agent = RandomPolicy()
            # previous learner can only be used for forward
            agent.forward = opponent.forward
            args.model_save_path = os.path.join(
                args.logdir, 'Gomoku', 'dqn', f'policy_round_{round}_epoch_{epoch}.pth'
            )
            result, agent_learn = train_agent(
                args, agent_learn=agent_learn, agent_opponent=agent, optim=optim
            )
            print(f'round_{round}_epoch_{epoch}')
            pprint.pprint(result)
            learnt_agent = deepcopy(agent_learn)
            learnt_agent.set_eps(0.0)
            opponent_pool.append(learnt_agent)
        args.epoch = total_epoch
        if __name__ == '__main__':
            # Let's watch its performance!
            opponent = opponent_pool[-2]
            watch(args, agent_learn, opponent)


if __name__ == '__main__':
    gomoku(get_args())
Deleted file (148 lines) — the old TicTacToeEnv built on MultiAgentEnv:

@@ -1,148 +0,0 @@
from functools import partial
from typing import Optional, Tuple

import gym
import numpy as np

from tianshou.env import MultiAgentEnv


class TicTacToeEnv(MultiAgentEnv):
    """This is a simple implementation of the Tic-Tac-Toe game, where two
    agents play against each other.

    The implementation is intended to show how to wrap an environment to
    satisfy the interface of :class:`~tianshou.env.MultiAgentEnv`.

    :param size: the size of the board (square board)
    :param win_size: how many units in a row is considered to win
    """

    def __init__(self, size: int = 3, win_size: int = 3):
        super().__init__()
        assert size > 0, f'board size should be positive, but got {size}'
        self.size = size
        assert win_size > 0, f'win-size should be positive, but got {win_size}'
        self.win_size = win_size
        assert win_size <= size, f'win-size {win_size} should not ' \
            f'be larger than board size {size}'
        self.convolve_kernel = np.ones(win_size)
        self.observation_space = gym.spaces.Box(
            low=-1.0, high=1.0, shape=(size, size), dtype=np.float32
        )
        self.action_space = gym.spaces.Discrete(size * size)
        self.current_board = None
        self.current_agent = None
        self._last_move = None
        self.step_num = None

    def reset(self) -> dict:
        self.current_board = np.zeros((self.size, self.size), dtype=np.int32)
        self.current_agent = 1
        self._last_move = (-1, -1)
        self.step_num = 0
        return {
            'agent_id': self.current_agent,
            'obs': np.array(self.current_board),
            'mask': self.current_board.flatten() == 0
        }

    def step(self, action: [int,
                            np.ndarray]) -> Tuple[dict, np.ndarray, np.ndarray, dict]:
        if self.current_agent is None:
            raise ValueError("calling step() of unreset environment is prohibited!")
        assert 0 <= action < self.size * self.size
        assert self.current_board.item(action) == 0
        _current_agent = self.current_agent
        self._move(action)
        mask = self.current_board.flatten() == 0
        is_win, is_opponent_win = False, False
        is_win = self._test_win()
        # the game is over when one wins or there is only one empty place
        done = is_win
        if sum(mask) == 1:
            done = True
            self._move(np.where(mask)[0][0])
            is_opponent_win = self._test_win()
        if is_win:
            reward = 1
        elif is_opponent_win:
            reward = -1
        else:
            reward = 0
        obs = {
            'agent_id': self.current_agent,
            'obs': np.array(self.current_board),
            'mask': mask
        }
        rew_agent_1 = reward if _current_agent == 1 else (-reward)
        rew_agent_2 = reward if _current_agent == 2 else (-reward)
        vec_rew = np.array([rew_agent_1, rew_agent_2], dtype=np.float32)
        if done:
            self.current_agent = None
        return obs, vec_rew, np.array(done), {}

    def _move(self, action):
        row, col = action // self.size, action % self.size
        if self.current_agent == 1:
            self.current_board[row, col] = 1
        else:
            self.current_board[row, col] = -1
        self.current_agent = 3 - self.current_agent
        self._last_move = (row, col)
        self.step_num += 1

    def _test_win(self):
        """test if someone wins by checking the situation around last move"""
        row, col = self._last_move
        rboard = self.current_board[row, :]
        cboard = self.current_board[:, col]
        current = self.current_board[row, col]
        rightup = [
            self.current_board[row - i, col + i] for i in range(1, self.size - col)
            if row - i >= 0
        ]
        leftdown = [
            self.current_board[row + i, col - i] for i in range(1, col + 1)
            if row + i < self.size
        ]
        rdiag = np.array(leftdown[::-1] + [current] + rightup)
        rightdown = [
            self.current_board[row + i, col + i] for i in range(1, self.size - col)
            if row + i < self.size
        ]
        leftup = [
            self.current_board[row - i, col - i] for i in range(1, col + 1)
            if row - i >= 0
        ]
        diag = np.array(leftup[::-1] + [current] + rightdown)
        results = [
            np.convolve(k, self.convolve_kernel, mode='valid')
            for k in (rboard, cboard, rdiag, diag)
        ]
        return any([(np.abs(x) == self.win_size).any() for x in results])

    def seed(self, seed: Optional[int] = None) -> int:
        pass

    def render(self, **kwargs) -> None:
        print(f'board (step {self.step_num}):')
        pad = '==='
        top = pad + '=' * (2 * self.size - 1) + pad
        print(top)

        def f(i, data):
            j, number = data
            last_move = i == self._last_move[0] and j == self._last_move[1]
            if number == 1:
                return 'X' if last_move else 'x'
            if number == -1:
                return 'O' if last_move else 'o'
            return '_'

        for i, row in enumerate(self.current_board):
            print(pad + ' '.join(map(partial(f, i), enumerate(row))) + pad)
        print(top)

    def close(self) -> None:
        pass
test/pettingzoo/pistonball.py — new file (185 lines)

@@ -0,0 +1,185 @@
import argparse
import os
from typing import List, Optional, Tuple

import gym
import numpy as np
import pettingzoo.butterfly.pistonball_v4 as pistonball_v4
import torch
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net


def get_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=1626)
    parser.add_argument('--eps-test', type=float, default=0.05)
    parser.add_argument('--eps-train', type=float, default=0.1)
    parser.add_argument('--buffer-size', type=int, default=2000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument(
        '--gamma', type=float, default=0.9, help='a smaller gamma favors earlier win'
    )
    parser.add_argument(
        '--n-pistons',
        type=int,
        default=3,
        help='Number of pistons(agents) in the env'
    )
    parser.add_argument('--n-step', type=int, default=100)
    parser.add_argument('--target-update-freq', type=int, default=320)
    parser.add_argument('--epoch', type=int, default=3)
    parser.add_argument('--step-per-epoch', type=int, default=500)
    parser.add_argument('--step-per-collect', type=int, default=10)
    parser.add_argument('--update-per-step', type=float, default=0.1)
    parser.add_argument('--batch-size', type=int, default=100)
    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
    parser.add_argument('--training-num', type=int, default=10)
    parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.0)

    parser.add_argument(
        '--watch',
        default=False,
        action='store_true',
        help='no training, '
        'watch the play of pre-trained models'
    )
    parser.add_argument(
        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
    )
    return parser


def get_args() -> argparse.Namespace:
    parser = get_parser()
    return parser.parse_known_args()[0]


def get_env(args: argparse.Namespace = get_args()):
    return PettingZooEnv(pistonball_v4.env(continuous=False, n_pistons=args.n_pistons))


def get_agents(
    args: argparse.Namespace = get_args(),
    agents: Optional[List[BasePolicy]] = None,
    optims: Optional[List[torch.optim.Optimizer]] = None,
) -> Tuple[BasePolicy, List[torch.optim.Optimizer], List]:
    env = get_env()
    observation_space = env.observation_space['observation'] if isinstance(
        env.observation_space, gym.spaces.Dict
    ) else env.observation_space
    args.state_shape = observation_space.shape or observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    if agents is None:
        agents = []
        optims = []
        for _ in range(args.n_pistons):
            # model
            net = Net(
                args.state_shape,
                args.action_shape,
                hidden_sizes=args.hidden_sizes,
                device=args.device
            ).to(args.device)
            optim = torch.optim.Adam(net.parameters(), lr=args.lr)
            agent = DQNPolicy(
                net,
                optim,
                args.gamma,
                args.n_step,
                target_update_freq=args.target_update_freq
            )
            agents.append(agent)
            optims.append(optim)

    policy = MultiAgentPolicyManager(agents, env)
    return policy, optims, env.agents


def train_agent(
    args: argparse.Namespace = get_args(),
    agents: Optional[List[BasePolicy]] = None,
    optims: Optional[List[torch.optim.Optimizer]] = None,
) -> Tuple[dict, BasePolicy]:
    train_envs = DummyVectorEnv([get_env for _ in range(args.training_num)])
    test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)

    policy, optim, agents = get_agents(args, agents=agents, optims=optims)

    # collector
    train_collector = Collector(
        policy,
        train_envs,
        VectorReplayBuffer(args.buffer_size, len(train_envs)),
        exploration_noise=True
    )
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    train_collector.collect(n_step=args.batch_size * args.training_num)
    # log
    log_path = os.path.join(args.logdir, 'pistonball', 'dqn')
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = TensorboardLogger(writer)

    def save_fn(policy):
        pass

    def stop_fn(mean_rewards):
        return False

    def train_fn(epoch, env_step):
        [agent.set_eps(args.eps_train) for agent in policy.policies.values()]

    def test_fn(epoch, env_step):
        [agent.set_eps(args.eps_test) for agent in policy.policies.values()]

    def reward_metric(rews):
        return rews[:, 0]

    # trainer
    result = offpolicy_trainer(
        policy,
        train_collector,
        test_collector,
        args.epoch,
        args.step_per_epoch,
        args.step_per_collect,
        args.test_num,
        args.batch_size,
        train_fn=train_fn,
        test_fn=test_fn,
        stop_fn=stop_fn,
        save_fn=save_fn,
        update_per_step=args.update_per_step,
        logger=logger,
        test_in_train=False,
        reward_metric=reward_metric
    )

    return result, policy


def watch(
    args: argparse.Namespace = get_args(), policy: Optional[BasePolicy] = None
) -> None:
    env = get_env()
    policy.eval()
    [agent.set_eps(args.eps_test) for agent in policy.policies.values()]
    collector = Collector(policy, env, exploration_noise=True)
    result = collector.collect(n_episode=1, render=args.render)
    rews, lens = result["rews"], result["lens"]
    print(f"Final reward: {rews[:, 0].mean()}, length: {lens.mean()}")
test/pettingzoo/pistonball_continuous.py — new file (276 lines)

@@ -0,0 +1,276 @@
import argparse
import os
from typing import Any, Dict, List, Optional, Tuple, Union

import gym
import numpy as np
import pettingzoo.butterfly.pistonball_v4 as pistonball_v4
import torch
import torch.nn as nn
from torch.distributions import Independent, Normal
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.policy import BasePolicy, MultiAgentPolicyManager, PPOPolicy
from tianshou.trainer import onpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.continuous import ActorProb, Critic


class DQN(nn.Module):
    """Reference: Human-level control through deep reinforcement learning.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.
    """

    def __init__(
        self,
        c: int,
        h: int,
        w: int,
        device: Union[str, int, torch.device] = "cpu",
    ) -> None:
        super().__init__()
        self.device = device
        self.c = c
        self.h = h
        self.w = w
        self.net = nn.Sequential(
            nn.Conv2d(c, 32, kernel_size=8, stride=4), nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(inplace=True),
            nn.Flatten()
        )
        with torch.no_grad():
            self.output_dim = np.prod(self.net(torch.zeros(1, c, h, w)).shape[1:])

    def forward(
        self,
        x: Union[np.ndarray, torch.Tensor],
        state: Optional[Any] = None,
        info: Dict[str, Any] = {},
    ) -> Tuple[torch.Tensor, Any]:
        r"""Mapping: x -> Q(x, \*)."""
        x = torch.as_tensor(x, device=self.device, dtype=torch.float32)
        return self.net(x.reshape(-1, self.c, self.w, self.h)), state


def get_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=1626)
    parser.add_argument('--eps-test', type=float, default=0.05)
    parser.add_argument('--eps-train', type=float, default=0.1)
    parser.add_argument('--buffer-size', type=int, default=2000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument(
        '--gamma', type=float, default=0.9, help='a smaller gamma favors earlier win'
    )
    parser.add_argument(
        '--n-pistons',
        type=int,
        default=3,
        help='Number of pistons(agents) in the env'
    )
    parser.add_argument('--n-step', type=int, default=100)
    parser.add_argument('--target-update-freq', type=int, default=320)
    parser.add_argument('--epoch', type=int, default=5)
    parser.add_argument('--step-per-epoch', type=int, default=500)
    parser.add_argument('--step-per-collect', type=int, default=10)
    parser.add_argument('--episode-per-collect', type=int, default=16)
    parser.add_argument('--repeat-per-collect', type=int, default=2)
    parser.add_argument('--update-per-step', type=float, default=0.1)
    parser.add_argument('--batch-size', type=int, default=1000)
    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
    parser.add_argument('--training-num', type=int, default=1000)
    parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')

    parser.add_argument(
        '--watch',
        default=False,
        action='store_true',
        help='no training, '
        'watch the play of pre-trained models'
    )
    parser.add_argument(
        '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
    )
    # ppo special
    parser.add_argument('--vf-coef', type=float, default=0.25)
    parser.add_argument('--ent-coef', type=float, default=0.0)
    parser.add_argument('--eps-clip', type=float, default=0.2)
    parser.add_argument('--max-grad-norm', type=float, default=0.5)
    parser.add_argument('--gae-lambda', type=float, default=0.95)
    parser.add_argument('--rew-norm', type=int, default=1)
    parser.add_argument('--dual-clip', type=float, default=None)
    parser.add_argument('--value-clip', type=int, default=1)
    parser.add_argument('--norm-adv', type=int, default=1)
    parser.add_argument('--recompute-adv', type=int, default=0)
    parser.add_argument('--resume', action="store_true")
    parser.add_argument("--save-interval", type=int, default=4)
    parser.add_argument('--render', type=float, default=0.0)

    return parser


def get_args() -> argparse.Namespace:
    parser = get_parser()
    return parser.parse_known_args()[0]


def get_env(args: argparse.Namespace = get_args()):
    return PettingZooEnv(pistonball_v4.env(continuous=True, n_pistons=args.n_pistons))


def get_agents(
    args: argparse.Namespace = get_args(),
    agents: Optional[List[BasePolicy]] = None,
    optims: Optional[List[torch.optim.Optimizer]] = None,
) -> Tuple[BasePolicy, List[torch.optim.Optimizer], List]:
    env = get_env()
    observation_space = env.observation_space['observation'] if isinstance(
        env.observation_space, gym.spaces.Dict
    ) else env.observation_space
    args.state_shape = observation_space.shape or observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]

    if agents is None:
        agents = []
        optims = []
        for _ in range(args.n_pistons):
            # model
            net = DQN(
                observation_space.shape[2],
                observation_space.shape[1],
                observation_space.shape[0],
                device=args.device
            ).to(args.device)

            actor = ActorProb(
                net, args.action_shape, max_action=args.max_action, device=args.device
            ).to(args.device)
            net2 = DQN(
                observation_space.shape[2],
                observation_space.shape[1],
                observation_space.shape[0],
                device=args.device
            ).to(args.device)
            critic = Critic(net2, device=args.device).to(args.device)
            for m in set(actor.modules()).union(critic.modules()):
                if isinstance(m, torch.nn.Linear):
                    torch.nn.init.orthogonal_(m.weight)
                    torch.nn.init.zeros_(m.bias)
            optim = torch.optim.Adam(
                set(actor.parameters()).union(critic.parameters()), lr=args.lr
            )

            def dist(*logits):
                return Independent(Normal(*logits), 1)

            agent = PPOPolicy(
                actor,
                critic,
                optim,
                dist,
                discount_factor=args.gamma,
                max_grad_norm=args.max_grad_norm,
                eps_clip=args.eps_clip,
                vf_coef=args.vf_coef,
                ent_coef=args.ent_coef,
                reward_normalization=args.rew_norm,
                advantage_normalization=args.norm_adv,
                recompute_advantage=args.recompute_adv,
                # dual_clip=args.dual_clip,
                # dual clip cause monotonically increasing log_std :)
                value_clip=args.value_clip,
                gae_lambda=args.gae_lambda,
                action_space=env.action_space
            )

            agents.append(agent)
            optims.append(optim)

    policy = MultiAgentPolicyManager(
        agents, env, action_scaling=True, action_bound_method='clip'
    )
    return policy, optims, env.agents


def train_agent(
    args: argparse.Namespace = get_args(),
    agents: Optional[List[BasePolicy]] = None,
    optims: Optional[List[torch.optim.Optimizer]] = None,
) -> Tuple[dict, BasePolicy]:
    train_envs = DummyVectorEnv([get_env for _ in range(args.training_num)])
    test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)

    policy, optim, agents = get_agents(args, agents=agents, optims=optims)

    # collector
    train_collector = Collector(
        policy,
        train_envs,
        VectorReplayBuffer(args.buffer_size, len(train_envs)),
        exploration_noise=False  # True
    )
    test_collector = Collector(policy, test_envs)
    # train_collector.collect(n_step=args.batch_size * args.training_num)
    # log
    log_path = os.path.join(args.logdir, 'pistonball', 'dqn')
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = TensorboardLogger(writer)

    def save_fn(policy):
        pass

    def stop_fn(mean_rewards):
        return False

    def train_fn(epoch, env_step):
        [agent.set_eps(args.eps_train) for agent in policy.policies.values()]

    def test_fn(epoch, env_step):
        [agent.set_eps(args.eps_test) for agent in policy.policies.values()]

    def reward_metric(rews):
        return rews[:, 0]

    # trainer
    result = onpolicy_trainer(
        policy,
        train_collector,
        test_collector,
        args.epoch,
        args.step_per_epoch,
        args.repeat_per_collect,
        args.test_num,
        args.batch_size,
        episode_per_collect=args.episode_per_collect,
        stop_fn=stop_fn,
        save_fn=save_fn,
        logger=logger,
        resume_from_log=args.resume
    )

    return result, policy


def watch(
    args: argparse.Namespace = get_args(), policy: Optional[BasePolicy] = None
) -> None:
    env = get_env()
    policy.eval()
    collector = Collector(policy, env)
    result = collector.collect(n_episode=1, render=args.render)
    rews, lens = result["rews"], result["lens"]
    print(f"Final reward: {rews[:, 0].mean()}, length: {lens.mean()}")
test/pettingzoo/test_pistonball.py — new file (21 lines)

@@ -0,0 +1,21 @@
import pprint

from pistonball import get_args, train_agent, watch


def test_piston_ball(args=get_args()):
    if args.watch:
        watch(args)
        return

    result, agent = train_agent(args)
    # assert result["best_reward"] >= args.win_rate

    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        watch(args, agent)


if __name__ == '__main__':
    test_piston_ball(get_args())
test/pettingzoo/test_pistonball_continuous.py — new file (21 lines)

@@ -0,0 +1,21 @@
import pprint

from pistonball_continuous import get_args, train_agent, watch


def test_piston_ball_continuous(args=get_args()):
    if args.watch:
        watch(args)
        return

    result, agent = train_agent(args)
    assert result["best_reward"] >= 30.0

    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        watch(args, agent)


if __name__ == '__main__':
    test_piston_ball_continuous(get_args())
Modified file (21 lines; the rendered diff shows identical content on both sides):

@@ -1,21 +1,21 @@
import pprint

from tic_tac_toe import get_args, train_agent, watch


def test_tic_tac_toe(args=get_args()):
    if args.watch:
        watch(args)
        return

    result, agent = train_agent(args)
    assert result["best_reward"] >= args.win_rate

    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        watch(args, agent)


if __name__ == '__main__':
    test_tic_tac_toe(get_args())
Modified file — the tic-tac-toe example, rewritten against PettingZooEnv:

@@ -1,232 +1,241 @@
 import argparse
 import os
 from copy import deepcopy
 from typing import Optional, Tuple
 
+import gym
 import numpy as np
 import torch
-from tic_tac_toe_env import TicTacToeEnv
+from pettingzoo.classic import tictactoe_v3
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.env import DummyVectorEnv
+from tianshou.env.pettingzoo_env import PettingZooEnv
 from tianshou.policy import (
     BasePolicy,
     DQNPolicy,
     MultiAgentPolicyManager,
     RandomPolicy,
 )
 from tianshou.trainer import offpolicy_trainer
 from tianshou.utils import TensorboardLogger
 from tianshou.utils.net.common import Net
 
 
+def get_env():
+    return PettingZooEnv(tictactoe_v3.env())
+
+
 def get_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser()
     parser.add_argument('--seed', type=int, default=1626)
     parser.add_argument('--eps-test', type=float, default=0.05)
     parser.add_argument('--eps-train', type=float, default=0.1)
     parser.add_argument('--buffer-size', type=int, default=20000)
-    parser.add_argument('--lr', type=float, default=1e-3)
+    parser.add_argument('--lr', type=float, default=1e-4)
     parser.add_argument(
         '--gamma', type=float, default=0.9, help='a smaller gamma favors earlier win'
     )
     parser.add_argument('--n-step', type=int, default=3)
     parser.add_argument('--target-update-freq', type=int, default=320)
-    parser.add_argument('--epoch', type=int, default=20)
-    parser.add_argument('--step-per-epoch', type=int, default=5000)
+    parser.add_argument('--epoch', type=int, default=50)
+    parser.add_argument('--step-per-epoch', type=int, default=1000)
     parser.add_argument('--step-per-collect', type=int, default=10)
     parser.add_argument('--update-per-step', type=float, default=0.1)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument(
         '--hidden-sizes', type=int, nargs='*', default=[128, 128, 128, 128]
     )
     parser.add_argument('--training-num', type=int, default=10)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.1)
-    parser.add_argument('--board-size', type=int, default=6)
-    parser.add_argument('--win-size', type=int, default=4)
     parser.add_argument(
-        '--win-rate', type=float, default=0.9, help='the expected winning rate'
+        '--win-rate',
+        type=float,
+        default=0.6,
+        help='the expected winning rate: Optimal policy can get 0.7'
     )
     parser.add_argument(
         '--watch',
         default=False,
         action='store_true',
         help='no training, '
         'watch the play of pre-trained models'
     )
     parser.add_argument(
         '--agent-id',
         type=int,
         default=2,
         help='the learned agent plays as the'
         ' agent_id-th player. Choices are 1 and 2.'
     )
     parser.add_argument(
         '--resume-path',
         type=str,
         default='',
         help='the path of agent pth file '
         'for resuming from a pre-trained agent'
     )
     parser.add_argument(
         '--opponent-path',
         type=str,
         default='',
         help='the path of opponent agent pth file '
         'for resuming from a pre-trained agent'
     )
     parser.add_argument(
         '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
     )
     return parser
 
 
 def get_args() -> argparse.Namespace:
     parser = get_parser()
     return parser.parse_known_args()[0]
 
 
 def get_agents(
     args: argparse.Namespace = get_args(),
     agent_learn: Optional[BasePolicy] = None,
     agent_opponent: Optional[BasePolicy] = None,
     optim: Optional[torch.optim.Optimizer] = None,
-) -> Tuple[BasePolicy, torch.optim.Optimizer]:
-    env = TicTacToeEnv(args.board_size, args.win_size)
-    args.state_shape = env.observation_space.shape or env.observation_space.n
+) -> Tuple[BasePolicy, torch.optim.Optimizer, list]:
+    env = get_env()
+    observation_space = env.observation_space['observation'] if isinstance(
+        env.observation_space, gym.spaces.Dict
+    ) else env.observation_space
+    args.state_shape = observation_space.shape or observation_space.n
     args.action_shape = env.action_space.shape or env.action_space.n
     if agent_learn is None:
         # model
         net = Net(
             args.state_shape,
             args.action_shape,
             hidden_sizes=args.hidden_sizes,
             device=args.device
         ).to(args.device)
         if optim is None:
             optim = torch.optim.Adam(net.parameters(), lr=args.lr)
         agent_learn = DQNPolicy(
             net,
             optim,
             args.gamma,
             args.n_step,
             target_update_freq=args.target_update_freq
         )
         if args.resume_path:
             agent_learn.load_state_dict(torch.load(args.resume_path))
 
     if agent_opponent is None:
         if args.opponent_path:
             agent_opponent = deepcopy(agent_learn)
             agent_opponent.load_state_dict(torch.load(args.opponent_path))
         else:
             agent_opponent = RandomPolicy()
 
     if args.agent_id == 1:
         agents = [agent_learn, agent_opponent]
     else:
         agents = [agent_opponent, agent_learn]
-    policy = MultiAgentPolicyManager(agents)
-    return policy, optim
+    policy = MultiAgentPolicyManager(agents, env)
+    return policy, optim, env.agents
 
 
 def train_agent(
     args: argparse.Namespace = get_args(),
     agent_learn: Optional[BasePolicy] = None,
     agent_opponent: Optional[BasePolicy] = None,
     optim: Optional[torch.optim.Optimizer] = None,
 ) -> Tuple[dict, BasePolicy]:
-
-    def env_func():
-        return TicTacToeEnv(args.board_size, args.win_size)
-
-    train_envs = DummyVectorEnv([env_func for _ in range(args.training_num)])
-    test_envs = DummyVectorEnv([env_func for _ in range(args.test_num)])
+    train_envs = DummyVectorEnv([get_env for _ in range(args.training_num)])
+    test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)])
     # seed
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
 
-    policy, optim = get_agents(
+    policy, optim, agents = get_agents(
         args, agent_learn=agent_learn, agent_opponent=agent_opponent, optim=optim
     )
 
     # collector
     train_collector = Collector(
         policy,
         train_envs,
         VectorReplayBuffer(args.buffer_size, len(train_envs)),
         exploration_noise=True
     )
     test_collector = Collector(policy, test_envs, exploration_noise=True)
     # policy.set_eps(1)
     train_collector.collect(n_step=args.batch_size * args.training_num)
     # log
     log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn')
     writer = SummaryWriter(log_path)
     writer.add_text("args", str(args))
     logger = TensorboardLogger(writer)
 
     def save_fn(policy):
         if hasattr(args, 'model_save_path'):
             model_save_path = args.model_save_path
         else:
             model_save_path = os.path.join(
                 args.logdir, 'tic_tac_toe', 'dqn', 'policy.pth'
             )
-        torch.save(policy.policies[args.agent_id - 1].state_dict(), model_save_path)
+        torch.save(
+            policy.policies[agents[args.agent_id - 1]].state_dict(), model_save_path
+        )
 
     def stop_fn(mean_rewards):
         return mean_rewards >= args.win_rate
 
     def train_fn(epoch, env_step):
-        policy.policies[args.agent_id - 1].set_eps(args.eps_train)
+        policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_train)
 
     def test_fn(epoch, env_step):
-        policy.policies[args.agent_id - 1].set_eps(args.eps_test)
+        policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test)
 
     def reward_metric(rews):
         return rews[:, args.agent_id - 1]
 
     # trainer
     result = offpolicy_trainer(
         policy,
         train_collector,
         test_collector,
         args.epoch,
         args.step_per_epoch,
         args.step_per_collect,
         args.test_num,
         args.batch_size,
         train_fn=train_fn,
         test_fn=test_fn,
         stop_fn=stop_fn,
         save_fn=save_fn,
         update_per_step=args.update_per_step,
         logger=logger,
         test_in_train=False,
         reward_metric=reward_metric
     )
 
-    return result, policy.policies[args.agent_id - 1]
+    return result, policy.policies[agents[args.agent_id - 1]]
 
 
 def watch(
     args: argparse.Namespace = get_args(),
     agent_learn: Optional[BasePolicy] = None,
     agent_opponent: Optional[BasePolicy] = None,
 ) -> None:
-    env = TicTacToeEnv(args.board_size, args.win_size)
-    policy, optim = get_agents(
+    env = get_env()
+    policy, optim, agents = get_agents(
         args, agent_learn=agent_learn, agent_opponent=agent_opponent
     )
     policy.eval()
-    policy.policies[args.agent_id - 1].set_eps(args.eps_test)
+    policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test)
     collector = Collector(policy, env, exploration_noise=True)
     result = collector.collect(n_episode=1, render=args.render)
     rews, lens = result["rews"], result["lens"]
     print(f"Final reward: {rews[:, args.agent_id - 1].mean()}, length: {lens.mean()}")
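The visible pattern change in the rewritten example is that MultiAgentPolicyManager now stores its sub-policies in a dict keyed by the PettingZoo agent name (returned by get_agents as env.agents), so per-agent lookups go through agents[args.agent_id - 1] instead of a bare positional index. A minimal sketch of that access pattern, reusing the functions defined in the script above (not part of the commit):

# Minimal sketch, not part of the diff: fetch one agent's policy by its
# PettingZoo agent name after building the policy manager.
args = get_args()
policy, optim, agents = get_agents(args)
learner = policy.policies[agents[args.agent_id - 1]]  # e.g. 'player_2'
learner.set_eps(args.eps_test)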
tianshou/env/__init__.py (4 changes)

@@ -1,6 +1,6 @@
 """Env package."""
 
-from tianshou.env.maenv import MultiAgentEnv
+from tianshou.env.pettingzoo_env import PettingZooEnv
 from tianshou.env.venvs import (
     BaseVectorEnv,
     DummyVectorEnv,
@@ -15,5 +15,5 @@ __all__ = [
     "SubprocVectorEnv",
     "ShmemVectorEnv",
     "RayVectorEnv",
-    "MultiAgentEnv",
+    "PettingZooEnv",
 ]
tianshou/env/maenv.py — deleted (65 lines)

@@ -1,65 +0,0 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Tuple

import gym
import numpy as np


class MultiAgentEnv(ABC, gym.Env):
    """The interface for multi-agent environments.

    Multi-agent environments must be wrapped as
    :class:`~tianshou.env.MultiAgentEnv`. Here is the usage:
    ::

        env = MultiAgentEnv(...)
        # obs is a dict containing obs, agent_id, and mask
        obs = env.reset()
        act = policy(obs)
        obs, rew, done, info = env.step(act)
        env.close()

    The available action's mask is set to 1, otherwise it is set to 0. Further
    usage can be found at :ref:`marl_example`.
    """

    def __init__(self) -> None:
        pass

    @abstractmethod
    def reset(self) -> dict:
        """Reset the state.

        Return the initial state, first agent_id, and the initial action set,
        for example, ``{'obs': obs, 'agent_id': agent_id, 'mask': mask}``.
        """
        pass

    @abstractmethod
    def step(
        self, action: np.ndarray
    ) -> Tuple[Dict[str, Any], np.ndarray, np.ndarray, np.ndarray]:
        """Run one timestep of the environment’s dynamics.

        When the end of episode is reached, you are responsible for calling
        reset() to reset the environment’s state.

        Accept action and return a tuple (obs, rew, done, info).

        :param numpy.ndarray action: action provided by a agent.

        :return: A tuple including four items:

            * ``obs`` a dict containing obs, agent_id, and mask, which means \
                that it is the ``agent_id`` player's turn to play with ``obs``\
                observation and ``mask``.
            * ``rew`` a numpy.ndarray, the amount of rewards returned after \
                previous actions. Depending on the specific environment, this \
                can be either a scalar reward for current agent or a vector \
                reward for all the agents.
            * ``done`` a numpy.ndarray, whether the episode has ended, in \
                which case further step() calls will return undefined results
            * ``info`` a numpy.ndarray, contains auxiliary diagnostic \
                information (helpful for debugging, and sometimes learning)
        """
        pass
tianshou/env/pettingzoo_env.py — new file (112 lines)

@@ -0,0 +1,112 @@
from abc import ABC
from typing import Any, Dict, List, Tuple

import gym.spaces
from pettingzoo.utils.env import AECEnv
from pettingzoo.utils.wrappers import BaseWrapper


class PettingZooEnv(AECEnv, gym.Env, ABC):
    """The interface for petting zoo environments.

    Multi-agent environments must be wrapped as
    :class:`~tianshou.env.PettingZooEnv`. Here is the usage:
    ::

        env = PettingZooEnv(...)
        # obs is a dict containing obs, agent_id, and mask
        obs = env.reset()
        action = policy(obs)
        obs, rew, done, info = env.step(action)
        env.close()

    The available action's mask is set to True, otherwise it is set to False.
    Further usage can be found at :ref:`marl_example`.
    """

    def __init__(self, env: BaseWrapper):
        super().__init__()
        self.env = env
        # agent idx list
        self.agents = self.env.possible_agents
        self.agent_idx = {}
        for i, agent_id in enumerate(self.agents):
            self.agent_idx[agent_id] = i
        # Get dictionaries of obs_spaces and act_spaces
        self.observation_spaces = self.env.observation_spaces
        self.action_spaces = self.env.action_spaces

        self.rewards = [0] * len(self.agents)

        # Get first observation space, assuming all agents have equal space
        self.observation_space: Any = self.observation_space(self.agents[0])

        # Get first action space, assuming all agents have equal space
        self.action_space: Any = self.action_space(self.agents[0])

        assert all(self.env.observation_space(agent) == self.observation_space
                   for agent in self.agents), \
            "Observation spaces for all agents must be identical. Perhaps " \
            "SuperSuit's pad_observations wrapper can help (useage: " \
            "`supersuit.aec_wrappers.pad_observations(env)`"

        assert all(self.env.action_space(agent) == self.action_space
                   for agent in self.agents), \
            "Action spaces for all agents must be identical. Perhaps " \
            "SuperSuit's pad_action_space wrapper can help (useage: " \
            "`supersuit.aec_wrappers.pad_action_space(env)`"

        self.reset()

    def reset(self) -> dict:
        self.env.reset()
        observation = self.env.observe(self.env.agent_selection)
        if isinstance(observation, dict) and 'action_mask' in observation:
            return {
                'agent_id': self.env.agent_selection,
                'obs': observation['observation'],
                'mask':
                [True if obm == 1 else False for obm in observation['action_mask']]
            }
        else:
            if isinstance(self.action_space, gym.spaces.Discrete):
                return {
                    'agent_id': self.env.agent_selection,
                    'obs': observation,
                    'mask': [True] * self.env.action_space(self.env.agent_selection).n
                }
            else:
                return {'agent_id': self.env.agent_selection, 'obs': observation}

    def step(self, action: Any) -> Tuple[Dict, List[int], bool, Dict]:
        self.env.step(action)
        observation, rew, done, info = self.env.last()
        if isinstance(observation, dict) and 'action_mask' in observation:
            obs = {
                'agent_id': self.env.agent_selection,
                'obs': observation['observation'],
                'mask':
                [True if obm == 1 else False for obm in observation['action_mask']]
            }
        else:
            if isinstance(self.action_space, gym.spaces.Discrete):
                obs = {
                    'agent_id': self.env.agent_selection,
                    'obs': observation,
                    'mask': [True] * self.env.action_space(self.env.agent_selection).n
                }
            else:
                obs = {'agent_id': self.env.agent_selection, 'obs': observation}

        for agent_id, reward in self.env.rewards.items():
            self.rewards[self.agent_idx[agent_id]] = reward
        return obs, self.rewards, done, info

    def close(self) -> None:
        self.env.close()

    def seed(self, seed: Any = None) -> None:
        self.env.seed(seed)

    def render(self, mode: str = "human") -> Any:
        return self.env.render(mode)
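As a quick orientation (not part of the commit), this is roughly how the wrapper behaves when driven by a turn-based PettingZoo game with an action mask; the sketch assumes only the pettingzoo version pinned in setup.py.

# Minimal sketch, not part of the diff: observations come back as dicts with
# 'agent_id', 'obs' and a boolean legal-action 'mask'.
from pettingzoo.classic import tictactoe_v3

from tianshou.env.pettingzoo_env import PettingZooEnv

env = PettingZooEnv(tictactoe_v3.env())
obs = env.reset()
print(obs['agent_id'], obs['mask'])
obs, rew, done, info = env.step(0)  # rew is a per-agent reward list
env.close()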
6
tianshou/env/venvs.py
vendored
6
tianshou/env/venvs.py
vendored
@@ -2,6 +2,7 @@ from typing import Any, Callable, List, Optional, Tuple, Union
 
 import gym
 import numpy as np
+import pettingzoo
 
 from tianshou.env.worker import (
     DummyEnvWorker,
@@ -364,7 +365,10 @@ class DummyVectorEnv(BaseVectorEnv):
     Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
     """
 
-    def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
+    def __init__(
+        self, env_fns: List[Callable[[], Union[gym.Env, pettingzoo.AECEnv]]],
+        **kwargs: Any
+    ) -> None:
         super().__init__(env_fns, DummyEnvWorker, **kwargs)
 
 
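A short sketch of what the widened type hint enables (hypothetical usage, not part of this commit): environment factories that build PettingZoo-backed environments can now be passed to DummyVectorEnv, shown here with the PettingZooEnv wrapper from above.

from pettingzoo.classic import tictactoe_v3

from tianshou.env import DummyVectorEnv
from tianshou.env.pettingzoo_env import PettingZooEnv


def make_env():
    # Each factory builds an independent wrapped environment.
    return PettingZooEnv(tictactoe_v3.env())


train_envs = DummyVectorEnv([make_env for _ in range(4)])
obs = train_envs.reset()  # batched observation dicts, one per sub-environment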
tianshou/policy/multiagent/mapolicy.py
@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import numpy as np
 
 from tianshou.data import Batch, ReplayBuffer
+from tianshou.env.pettingzoo_env import PettingZooEnv
 from tianshou.policy import BasePolicy
 
 
@@ -16,21 +17,29 @@ class MultiAgentPolicyManager(BasePolicy):
     :ref:`marl_example` can help you better understand this procedure.
     """
 
-    def __init__(self, policies: List[BasePolicy], **kwargs: Any) -> None:
-        super().__init__(**kwargs)
-        self.policies = policies
+    def __init__(
+        self, policies: List[BasePolicy], env: PettingZooEnv, **kwargs: Any
+    ) -> None:
+        super().__init__(action_space=env.action_space, **kwargs)
+        assert (
+            len(policies) == len(env.agents)
+        ), "One policy must be assigned for each agent."
+
+        self.agent_idx = env.agent_idx
         for i, policy in enumerate(policies):
             # agent_id 0 is reserved for the environment proxy
             # (this MultiAgentPolicyManager)
-            policy.set_agent_id(i + 1)
+            policy.set_agent_id(env.agents[i])
+
+        self.policies = dict(zip(env.agents, policies))
 
     def replace_policy(self, policy: BasePolicy, agent_id: int) -> None:
         """Replace the "agent_id"th policy in this manager."""
-        self.policies[agent_id - 1] = policy
         policy.set_agent_id(agent_id)
+        self.policies[agent_id] = policy
 
     def process_fn(
-        self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
+        self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
     ) -> Batch:
         """Dispatch batch data from obs.agent_id to every policy's process_fn.
 
@@ -45,18 +54,21 @@ class MultiAgentPolicyManager(BasePolicy):
         # Since we do not override buffer.__setattr__, here we use _meta to
         # change buffer.rew, otherwise buffer.rew = Batch() has no effect.
         save_rew, buffer._meta.rew = buffer.rew, Batch()
-        for policy in self.policies:
-            agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0]
+        for agent, policy in self.policies.items():
+            agent_index = np.nonzero(batch.obs.agent_id == agent)[0]
             if len(agent_index) == 0:
-                results[f"agent_{policy.agent_id}"] = Batch()
+                results[agent] = Batch()
                 continue
-            tmp_batch, tmp_indices = batch[agent_index], indices[agent_index]
+            tmp_batch, tmp_indice = batch[agent_index], indice[agent_index]
             if has_rew:
-                tmp_batch.rew = tmp_batch.rew[:, policy.agent_id - 1]
-                buffer._meta.rew = save_rew[:, policy.agent_id - 1]
-            results[f"agent_{policy.agent_id}"] = policy.process_fn(
-                tmp_batch, buffer, tmp_indices
-            )
+                tmp_batch.rew = tmp_batch.rew[:, self.agent_idx[agent]]
+                buffer._meta.rew = save_rew[:, self.agent_idx[agent]]
+            if not hasattr(tmp_batch.obs, "mask"):
+                if hasattr(tmp_batch.obs, 'obs'):
+                    tmp_batch.obs = tmp_batch.obs.obs
+                if hasattr(tmp_batch.obs_next, 'obs'):
+                    tmp_batch.obs_next = tmp_batch.obs_next.obs
+            results[agent] = policy.process_fn(tmp_batch, buffer, tmp_indice)
         if has_rew:  # restore from save_rew
             buffer._meta.rew = save_rew
         return Batch(results)
@@ -64,8 +76,8 @@ class MultiAgentPolicyManager(BasePolicy):
     def exploration_noise(self, act: Union[np.ndarray, Batch],
                           batch: Batch) -> Union[np.ndarray, Batch]:
         """Add exploration noise from sub-policy onto act."""
-        for policy in self.policies:
-            agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0]
+        for agent_id, policy in self.policies.items():
+            agent_index = np.nonzero(batch.obs.agent_id == agent_id)[0]
             if len(agent_index) == 0:
                 continue
             act[agent_index] = policy.exploration_noise(
@@ -104,7 +116,7 @@ class MultiAgentPolicyManager(BasePolicy):
         """
         results: List[Tuple[bool, np.ndarray, Batch, Union[np.ndarray, Batch],
                             Batch]] = []
-        for policy in self.policies:
+        for agent_id, policy in self.policies.items():
             # This part of code is difficult to understand.
             # Let's follow an example with two agents
             # batch.obs.agent_id is [1, 2, 1, 2, 1, 2] (with batch_size == 6)
@@ -112,7 +124,7 @@ class MultiAgentPolicyManager(BasePolicy):
             # agent_index for agent 1 is [0, 2, 4]
             # agent_index for agent 2 is [1, 3, 5]
             # we separate the transition of each agent according to agent_id
-            agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0]
+            agent_index = np.nonzero(batch.obs.agent_id == agent_id)[0]
             if len(agent_index) == 0:
                 # (has_data, agent_index, out, act, state)
                 results.append((False, np.array([-1]), Batch(), Batch(), Batch()))
@@ -120,11 +132,15 @@ class MultiAgentPolicyManager(BasePolicy):
             tmp_batch = batch[agent_index]
             if isinstance(tmp_batch.rew, np.ndarray):
                 # reward can be empty Batch (after initial reset) or nparray.
-                tmp_batch.rew = tmp_batch.rew[:, policy.agent_id - 1]
+                tmp_batch.rew = tmp_batch.rew[:, self.agent_idx[agent_id]]
+            if not hasattr(tmp_batch.obs, "mask"):
+                if hasattr(tmp_batch.obs, 'obs'):
+                    tmp_batch.obs = tmp_batch.obs.obs
+                if hasattr(tmp_batch.obs_next, 'obs'):
+                    tmp_batch.obs_next = tmp_batch.obs_next.obs
             out = policy(
                 batch=tmp_batch,
-                state=None if state is None else state["agent_" +
-                                                       str(policy.agent_id)],
+                state=None if state is None else state[agent_id],
                 **kwargs
             )
             act = out.act
@@ -141,12 +157,12 @@ class MultiAgentPolicyManager(BasePolicy):
             ]
         )
         state_dict, out_dict = {}, {}
-        for policy, (has_data, agent_index, out, act,
-                     state) in zip(self.policies, results):
+        for (agent_id, _), (has_data, agent_index, out, act,
+                            state) in zip(self.policies.items(), results):
             if has_data:
                 holder.act[agent_index] = act
-            state_dict["agent_" + str(policy.agent_id)] = state
-            out_dict["agent_" + str(policy.agent_id)] = out
+            state_dict[agent_id] = state
+            out_dict[agent_id] = out
         holder["out"] = out_dict
         holder["state"] = state_dict
         return holder
@@ -168,10 +184,10 @@ class MultiAgentPolicyManager(BasePolicy):
         }
         """
         results = {}
-        for policy in self.policies:
-            data = batch[f"agent_{policy.agent_id}"]
+        for agent_id, policy in self.policies.items():
+            data = batch[agent_id]
             if not data.is_empty():
                 out = policy.learn(batch=data, **kwargs)
                 for k, v in out.items():
-                    results["agent_" + str(policy.agent_id) + "/" + k] = v
+                    results[agent_id + "/" + k] = v
         return results
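Taken together, the manager is now constructed from the environment so that sub-policies are keyed by PettingZoo agent names rather than integer ids. A hedged construction sketch (RandomPolicy is only a placeholder; any BasePolicy works, and tictactoe_v3 is used purely for illustration):

from pettingzoo.classic import tictactoe_v3

from tianshou.env.pettingzoo_env import PettingZooEnv
from tianshou.policy import MultiAgentPolicyManager, RandomPolicy

env = PettingZooEnv(tictactoe_v3.env())
# One policy per agent, ordered like env.agents.
policy = MultiAgentPolicyManager([RandomPolicy(), RandomPolicy()], env)
# Sub-policies are now looked up and replaced by agent name.
policy.replace_policy(RandomPolicy(), env.agents[0])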