Update Multi-agent RL docs, upgrade pettingzoo (#595)
* update multi-agent docs, upgrade pettingzoo
* avoid pettingzoo deprecation warning
* fix pistonball tests
* codestyle
This commit is contained in:
parent
18277497ed
commit
6fc6857812
File diff suppressed because it is too large
Load Diff
2
setup.py
2
setup.py
@@ -46,7 +46,7 @@ def get_extras_require() -> str:
|
|||||||
"doc8",
|
"doc8",
|
||||||
"scipy",
|
"scipy",
|
||||||
"pillow",
|
"pillow",
|
||||||
"pettingzoo>=1.12",
|
"pettingzoo>=1.17",
|
||||||
"pygame>=2.1.0", # pettingzoo test cases pistonball
|
"pygame>=2.1.0", # pettingzoo test cases pistonball
|
||||||
"pymunk>=6.2.1", # pettingzoo test cases pistonball
|
"pymunk>=6.2.1", # pettingzoo test cases pistonball
|
||||||
"nni>=2.3",
|
"nni>=2.3",
|
||||||
|
@@ -176,7 +176,7 @@ def train_agent(
|
|||||||
def watch(
|
def watch(
|
||||||
args: argparse.Namespace = get_args(), policy: Optional[BasePolicy] = None
|
args: argparse.Namespace = get_args(), policy: Optional[BasePolicy] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
env = get_env()
|
env = DummyVectorEnv([get_env])
|
||||||
policy.eval()
|
policy.eval()
|
||||||
[agent.set_eps(args.eps_test) for agent in policy.policies.values()]
|
[agent.set_eps(args.eps_test) for agent in policy.policies.values()]
|
||||||
collector = Collector(policy, env, exploration_noise=True)
|
collector = Collector(policy, env, exploration_noise=True)
|
||||||
|
@@ -268,7 +268,7 @@ def train_agent(
|
|||||||
def watch(
|
def watch(
|
||||||
args: argparse.Namespace = get_args(), policy: Optional[BasePolicy] = None
|
args: argparse.Namespace = get_args(), policy: Optional[BasePolicy] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
env = get_env()
|
env = DummyVectorEnv([get_env])
|
||||||
policy.eval()
|
policy.eval()
|
||||||
collector = Collector(policy, env)
|
collector = Collector(policy, env)
|
||||||
result = collector.collect(n_episode=1, render=args.render)
|
result = collector.collect(n_episode=1, render=args.render)
|
||||||
|
@@ -229,7 +229,7 @@ def watch(
|
|||||||
agent_learn: Optional[BasePolicy] = None,
|
agent_learn: Optional[BasePolicy] = None,
|
||||||
agent_opponent: Optional[BasePolicy] = None,
|
agent_opponent: Optional[BasePolicy] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
env = get_env()
|
env = DummyVectorEnv([get_env])
|
||||||
policy, optim, agents = get_agents(
|
policy, optim, agents = get_agents(
|
||||||
args, agent_learn=agent_learn, agent_opponent=agent_opponent
|
args, agent_learn=agent_learn, agent_opponent=agent_opponent
|
||||||
)
|
)
|
||||||
|
7
tianshou/env/pettingzoo_env.py
vendored
7
tianshou/env/pettingzoo_env.py
vendored
@@ -32,17 +32,14 @@ class PettingZooEnv(AECEnv, ABC):
|
|||||||
self.agent_idx = {}
|
self.agent_idx = {}
|
||||||
for i, agent_id in enumerate(self.agents):
|
for i, agent_id in enumerate(self.agents):
|
||||||
self.agent_idx[agent_id] = i
|
self.agent_idx[agent_id] = i
|
||||||
# Get dictionaries of obs_spaces and act_spaces
|
|
||||||
self.observation_spaces = self.env.observation_spaces
|
|
||||||
self.action_spaces = self.env.action_spaces
|
|
||||||
|
|
||||||
self.rewards = [0] * len(self.agents)
|
self.rewards = [0] * len(self.agents)
|
||||||
|
|
||||||
# Get first observation space, assuming all agents have equal space
|
# Get first observation space, assuming all agents have equal space
|
||||||
self.observation_space: Any = self.observation_space(self.agents[0])
|
self.observation_space: Any = self.env.observation_space(self.agents[0])
|
||||||
|
|
||||||
# Get first action space, assuming all agents have equal space
|
# Get first action space, assuming all agents have equal space
|
||||||
self.action_space: Any = self.action_space(self.agents[0])
|
self.action_space: Any = self.env.action_space(self.agents[0])
|
||||||
|
|
||||||
assert all(self.env.observation_space(agent) == self.observation_space
|
assert all(self.env.observation_space(agent) == self.observation_space
|
||||||
for agent in self.agents), \
|
for agent in self.agents), \
|
||||||
|
Loading…
x
Reference in New Issue
Block a user