Update Multi-agent RL docs, upgrade pettingzoo (#595)

* update multi-agent docs, upgrade pettingzoo

* avoid pettingzoo deprecation warning

* fix pistonball tests

* codestyle
This commit is contained in:
Yifei Cheng 2022-04-16 11:17:53 -04:00 committed by GitHub
parent 18277497ed
commit 6fc6857812
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 442 additions and 475 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -46,7 +46,7 @@ def get_extras_require() -> str:
"doc8",
"scipy",
"pillow",
"pettingzoo>=1.12",
"pettingzoo>=1.17",
"pygame>=2.1.0", # pettingzoo test cases pistonball
"pymunk>=6.2.1", # pettingzoo test cases pistonball
"nni>=2.3",

View File

@@ -176,7 +176,7 @@ def train_agent(
def watch(
args: argparse.Namespace = get_args(), policy: Optional[BasePolicy] = None
) -> None:
env = get_env()
env = DummyVectorEnv([get_env])
policy.eval()
[agent.set_eps(args.eps_test) for agent in policy.policies.values()]
collector = Collector(policy, env, exploration_noise=True)

View File

@@ -268,7 +268,7 @@ def train_agent(
def watch(
args: argparse.Namespace = get_args(), policy: Optional[BasePolicy] = None
) -> None:
env = get_env()
env = DummyVectorEnv([get_env])
policy.eval()
collector = Collector(policy, env)
result = collector.collect(n_episode=1, render=args.render)

View File

@@ -229,7 +229,7 @@ def watch(
agent_learn: Optional[BasePolicy] = None,
agent_opponent: Optional[BasePolicy] = None,
) -> None:
env = get_env()
env = DummyVectorEnv([get_env])
policy, optim, agents = get_agents(
args, agent_learn=agent_learn, agent_opponent=agent_opponent
)

View File

@@ -32,17 +32,14 @@ class PettingZooEnv(AECEnv, ABC):
self.agent_idx = {}
for i, agent_id in enumerate(self.agents):
self.agent_idx[agent_id] = i
# Get dictionaries of obs_spaces and act_spaces
self.observation_spaces = self.env.observation_spaces
self.action_spaces = self.env.action_spaces
self.rewards = [0] * len(self.agents)
# Get first observation space, assuming all agents have equal space
self.observation_space: Any = self.observation_space(self.agents[0])
self.observation_space: Any = self.env.observation_space(self.agents[0])
# Get first action space, assuming all agents have equal space
self.action_space: Any = self.action_space(self.agents[0])
self.action_space: Any = self.env.action_space(self.agents[0])
assert all(self.env.observation_space(agent) == self.observation_space
for agent in self.agents), \