youkaichao 8c32d99c65
Add multi-agent example: tic-tac-toe (#122)
* make fields with empty Batch rather than None after reset

* dummy code

* remove dummy

* add reward_length argument for collector

* Improve Batch (#126)

* make sure the key type of Batch is string, and add unit tests

* add is_empty() function and unit tests

* enable cat of mixing dict and Batch, just like stack

* bugfix for reward_length

* add get_final_reward_fn argument to collector to deal with marl

* minor polish

* remove multibuf

* minor polish

* improve and implement Batch.cat_

* bugfix for buffer.sample with field impt_weight

* restore the usage of a.cat_(b)

* fix 2 bugs in batch and add corresponding unittest

* code fix for update

* update is_empty to recognize empty over empty; bugfix for len

* bugfix for update and add testcase

* add testcase of update

* make fields with empty Batch rather than None after reset

* dummy code

* remove dummy

* add reward_length argument for collector

* bugfix for reward_length

* add get_final_reward_fn argument to collector to deal with marl

* make sure the key type of Batch is string, and add unit tests

* add is_empty() function and unit tests

* enable cat of mixing dict and Batch, just like stack

* dummy code

* remove dummy

* add multi-agent example: tic-tac-toe

* move TicTacToeEnv to a separate file

* remove dummy MANet

* code refactor

* move tic-tac-toe example to test

* update doc with marl-example

* fix docs

* reduce the threshold

* revert

* update player id to start from 1 and change player to agent; keep coding

* add reward_length argument for collector

* Improve Batch (#128)

* minor polish

* improve and implement Batch.cat_

* bugfix for buffer.sample with field impt_weight

* restore the usage of a.cat_(b)

* fix 2 bugs in batch and add corresponding unittest

* code fix for update

* update is_empty to recognize empty over empty; bugfix for len

* bugfix for update and add testcase

* add testcase of update

* fix docs

* fix docs

* fix docs [ci skip]

* fix docs [ci skip]

Co-authored-by: Trinkle23897 <463003665@qq.com>

* refactor

* re-implement Batch.stack and add testcases

* add doc for Batch.stack

* reward_metric

* modify flag

* minor fix

* reuse _create_values and refactor stack_ & cat_

* fix pep8

* fix reward stat in collector

* fix stat of collector, simplify test/base/env.py

* fix docs

* minor fix

* raise exception for stacking with partial keys and axis!=0

* minor fix

* minor fix

* minor fix

* marl-examples

* add condense; bugfix for torch.Tensor; code refactor

* marl example can run now

* enable tic tac toe with larger board size and win-size

* add test dependency

* Fix padding of inconsistent keys with Batch.stack and Batch.cat (#130)

* re-implement Batch.stack and add testcases

* add doc for Batch.stack

* reuse _create_values and refactor stack_ & cat_

* fix pep8

* fix docs

* raise exception for stacking with partial keys and axis!=0

* minor fix

* minor fix

Co-authored-by: Trinkle23897 <463003665@qq.com>

* stash

* let agent learn to play as agent 2 which is harder

* code refactor

* Improve collector (#125)

* remove multibuf

* reward_metric

* make fields with empty Batch rather than None after reset

* many fixes and refactor
Co-authored-by: Trinkle23897 <463003665@qq.com>

* marl for tic-tac-toe and general gomoku

* update default gamma to 0.1 for tic tac toe to win earlier

* fix name typo; change default game config; add rew_norm option

* fix pep8

* test commit

* mv test dir name

* add rew flag

* fix torch.optim import error and madqn rew_norm

* remove useless kwargs

* Vector env enable select worker (#132)

* Enable selecting worker for vector env step method.

* Update collector to match new vecenv selective worker behavior.

* Bug fix.

* Fix rebase

Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>

* show the last move of tictactoe by capital letters

* add multi-agent tutorial

* fix link

* Standardized behavior of Batch.cat and misc code refactor (#137)

* code refactor; remove unused kwargs; add reward_normalization for dqn

* bugfix for __setitem__ with torch.Tensor; add Batch.condense

* minor fix

* support cat with empty Batch

* remove the dependency of is_empty on len; specify the semantic of empty Batch by test cases

* support stack with empty Batch

* remove condense

* refactor code to reflect the shared / partial / reserved categories of keys

* add is_empty(recursive=False)

* doc fix

* docfix and bugfix for _is_batch_set

* add doc for key reservation

* bugfix for algebra operators

* fix cat with lens hint

* code refactor

* bugfix for storing None

* use ValueError instead of exception

* hide lens away from users

* add comment for __cat

* move the computation of the initial value of lens in cat_ itself.

* change the place of doc string

* doc fix for Batch doc string

* change recursive to recurse

* doc string fix

* minor fix for batch doc

* write tutorials to specify the standard of Batch (#142)

* add doc for len exceptions

* doc move; unify is_scalar_value function

* remove some issubclass check

* bugfix for shape of Batch(a=1)

* keep moving doc

* keep writing batch tutorial

* draft version of Batch tutorial done

* improving doc

* keep improving doc

* batch tutorial done

* rename _is_number

* rename _is_scalar

* shape property do not raise exception

* restore some doc string

* grammarly [ci skip]

* grammarly + fix warning of building docs

* polish docs

* trim and re-arrange batch tutorial

* go straight to the point

* minor fix for batch doc

* add shape / len in basic usage

* keep improving tutorial

* unify _to_array_with_correct_type to remove duplicate code

* delegate type conversion to Batch.__init__

* further delegate type conversion to Batch.__init__

* bugfix for setattr

* add a _parse_value function

* remove dummy function call

* polish docs

Co-authored-by: Trinkle23897 <463003665@qq.com>

* bugfix for mapolicy

* pretty code

* remove debug code; remove condense

* doc fix

* check before get_agents in tutorials/tictactoe

* tutorial

* fix

* minor fix for batch doc

* minor polish

* faster test_ttt

* improve tic-tac-toe environment

* change default epoch and step-per-epoch for tic-tac-toe

* fix mapolicy

* minor polish for mapolicy

* 90% to 80% (need to change the tutorial)

* win rate

* show step number at board

* simplify mapolicy

* minor polish for mapolicy

* remove MADQN

* fix pep8

* change legal_actions to mask (need to update docs)

* simplify maenv

* fix typo

* move basevecenv to single file

* separate RandomAgent

* update docs

* grammarly

* fix pep8

* win rate typo

* format in cheatsheet

* use bool mask directly

* update doc for boolean mask

Co-authored-by: Trinkle23897 <463003665@qq.com>
Co-authored-by: Alexis DUBURCQ <alexis.duburcq@gmail.com>
Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
2020-07-21 14:59:49 +08:00


import numpy as np
from typing import Union, Optional, Dict, List

from tianshou.policy import BasePolicy
from tianshou.data import Batch, ReplayBuffer


class MultiAgentPolicyManager(BasePolicy):
    """This multi-agent policy manager accepts a list of
    :class:`~tianshou.policy.BasePolicy`. It dispatches the batch data to
    each of these policies when "forward" is called. The same applies to
    "process_fn" and "learn": it splits the data and feeds each part to the
    corresponding policy. A figure in :ref:`marl_example` can help you better
    understand this procedure.
    """

    def __init__(self, policies: List[BasePolicy]):
        super().__init__()
        self.policies = policies
        for i, policy in enumerate(policies):
            # agent_id 0 is reserved for the environment proxy
            # (this MultiAgentPolicyManager)
            policy.set_agent_id(i + 1)

    def replace_policy(self, policy, agent_id):
        """Replace the "agent_id"-th policy in this manager."""
        self.policies[agent_id - 1] = policy
        policy.set_agent_id(agent_id)
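
    # Illustrative note (not in the original file): agent ids are 1-based, so
    # swapping out the second agent's policy at runtime could look like
    #     manager.replace_policy(new_policy, agent_id=2)
    # which overwrites self.policies[1] and re-tags `new_policy` with id 2;
    # `manager` and `new_policy` are hypothetical, pre-built objects here.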

    def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                   indice: np.ndarray) -> Batch:
        """Save the original multi-dimensional rew in "save_rew", set rew to
        the reward of each agent during their ``process_fn``, and restore the
        original reward afterwards.
        """
        results = {}
        # reward can be empty Batch (after initial reset) or np.ndarray.
        has_rew = isinstance(buffer.rew, np.ndarray)
        if has_rew:  # save the original reward in save_rew
            save_rew, buffer.rew = buffer.rew, Batch()
        for policy in self.policies:
            agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0]
            if len(agent_index) == 0:
                results[f'agent_{policy.agent_id}'] = Batch()
                continue
            tmp_batch, tmp_indice = batch[agent_index], indice[agent_index]
            if has_rew:
                tmp_batch.rew = tmp_batch.rew[:, policy.agent_id - 1]
                buffer.rew = save_rew[:, policy.agent_id - 1]
            results[f'agent_{policy.agent_id}'] = \
                policy.process_fn(tmp_batch, buffer, tmp_indice)
        if has_rew:  # restore from save_rew
            buffer.rew = save_rew
        return Batch(results)
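
    # Illustrative note on the reward layout (implied by the slicing above,
    # spelled out here rather than in the original file): in this multi-agent
    # setting the buffer's rew is a 2-D array of shape (buffer_size,
    # n_agents), e.g. for two agents
    #     rew = np.array([[0., 0.], [1., -1.]])
    # so rew[:, policy.agent_id - 1] extracts one agent's reward column
    # before handing the data to that agent's own process_fn.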

    def forward(self, batch: Batch,
                state: Optional[Union[dict, Batch]] = None,
                **kwargs) -> Batch:
        """:param state: if None, it means all agents have no state. If not
        None, it should contain keys "agent_1", "agent_2", ...

        :return: a Batch with the following contents:

        ::

            {
                "act": actions corresponding to the input
                "state": {
                    "agent_1": output state of agent_1's policy for the state
                    "agent_2": xxx
                    ...
                    "agent_n": xxx}
                "out": {
                    "agent_1": output of agent_1's policy for the input
                    "agent_2": xxx
                    ...
                    "agent_n": xxx}
            }
        """
        results = []
        for policy in self.policies:
            # This part of the code is difficult to understand.
            # Let's follow an example with two agents:
            # batch.obs.agent_id is [1, 2, 1, 2, 1, 2] (with batch_size == 6)
            # and each agent plays for three transitions.
            # agent_index for agent 1 is [0, 2, 4]
            # agent_index for agent 2 is [1, 3, 5]
            # We separate the transitions of each agent according to agent_id.
            agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0]
            if len(agent_index) == 0:
                # (has_data, agent_index, out, act, state)
                results.append((False, None, Batch(), None, Batch()))
                continue
            tmp_batch = batch[agent_index]
            if isinstance(tmp_batch.rew, np.ndarray):
                # reward can be empty Batch (after reset) or np.ndarray.
                tmp_batch.rew = tmp_batch.rew[:, policy.agent_id - 1]
            out = policy(batch=tmp_batch, state=None if state is None
                         else state["agent_" + str(policy.agent_id)],
                         **kwargs)
            act = out.act
            each_state = out.state \
                if (hasattr(out, 'state') and out.state is not None) \
                else Batch()
            results.append((True, agent_index, out, act, each_state))
        holder = Batch.cat([{'act': act} for
                            (has_data, agent_index, out, act, each_state)
                            in results if has_data])
        state_dict, out_dict = {}, {}
        for policy, (has_data, agent_index, out, act, state) in \
                zip(self.policies, results):
            if has_data:
                holder.act[agent_index] = act
            state_dict["agent_" + str(policy.agent_id)] = state
            out_dict["agent_" + str(policy.agent_id)] = out
        holder["out"] = out_dict
        holder["state"] = state_dict
        return holder
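
    # Continuing the two-agent example in the comments above (note added for
    # illustration, not in the original file): with batch.obs.agent_id equal
    # to [1, 2, 1, 2, 1, 2], holder.act has length 6; indices [0, 2, 4] are
    # filled with agent 1's actions and [1, 3, 5] with agent 2's, so the
    # returned joint action array keeps the original ordering of the input
    # batch.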

    def learn(self, batch: Batch, **kwargs
              ) -> Dict[str, Union[float, List[float]]]:
        """:return: a dict with the following contents:

        ::

            {
                "agent_1/item1": item 1 of agent_1's policy.learn output
                "agent_1/item2": item 2 of agent_1's policy.learn output
                "agent_2/xxx": xxx
                ...
                "agent_n/xxx": xxx
            }
        """
        results = {}
        for policy in self.policies:
            data = batch[f'agent_{policy.agent_id}']
            if not data.is_empty():
                out = policy.learn(batch=data, **kwargs)
                for k, v in out.items():
                    results["agent_" + str(policy.agent_id) + '/' + k] = v
        return results
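
Below is a minimal, hypothetical usage sketch (not part of the repository) showing how the manager dispatches a batch. `FixedPolicy` is a made-up stub standing in for the DQN policies used in the tic-tac-toe test, and the hand-built `Batch` mimics what the collector would produce for a turn-based two-agent environment.

import numpy as np

from tianshou.data import Batch
from tianshou.policy import BasePolicy, MultiAgentPolicyManager


class FixedPolicy(BasePolicy):
    """Hypothetical stub: always plays action 0 and learns nothing."""

    def forward(self, batch, state=None, **kwargs):
        return Batch(act=np.zeros(len(batch.obs.agent_id), dtype=int))

    def learn(self, batch, **kwargs):
        return {}


manager = MultiAgentPolicyManager([FixedPolicy(), FixedPolicy()])

# Four transitions alternating between agent 1 and agent 2, with a flat
# 9-cell observation as in 3x3 tic-tac-toe; rew is an empty Batch, as it is
# right after the collector's initial reset.
batch = Batch(obs=Batch(agent_id=np.array([1, 2, 1, 2]),
                        obs=np.zeros((4, 9))),
              rew=Batch())
result = manager(batch)
# result.act holds one action per transition in the original order, while
# result.out.agent_1 / result.out.agent_2 keep each policy's raw output.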