Update Multi-agent RL docs, upgrade pettingzoo (#595)
* update multi-agent docs, upgrade pettingzoo * avoid pettingzoo deprecation warning * fix pistonball tests * codestyle
This commit is contained in:
parent
18277497ed
commit
6fc6857812
File diff suppressed because it is too large
Load Diff
2
setup.py
2
setup.py
@ -46,7 +46,7 @@ def get_extras_require() -> str:
|
||||
"doc8",
|
||||
"scipy",
|
||||
"pillow",
|
||||
"pettingzoo>=1.12",
|
||||
"pettingzoo>=1.17",
|
||||
"pygame>=2.1.0", # pettingzoo test cases pistonball
|
||||
"pymunk>=6.2.1", # pettingzoo test cases pistonball
|
||||
"nni>=2.3",
|
||||
|
@ -176,7 +176,7 @@ def train_agent(
|
||||
def watch(
|
||||
args: argparse.Namespace = get_args(), policy: Optional[BasePolicy] = None
|
||||
) -> None:
|
||||
env = get_env()
|
||||
env = DummyVectorEnv([get_env])
|
||||
policy.eval()
|
||||
[agent.set_eps(args.eps_test) for agent in policy.policies.values()]
|
||||
collector = Collector(policy, env, exploration_noise=True)
|
||||
|
@ -268,7 +268,7 @@ def train_agent(
|
||||
def watch(
|
||||
args: argparse.Namespace = get_args(), policy: Optional[BasePolicy] = None
|
||||
) -> None:
|
||||
env = get_env()
|
||||
env = DummyVectorEnv([get_env])
|
||||
policy.eval()
|
||||
collector = Collector(policy, env)
|
||||
result = collector.collect(n_episode=1, render=args.render)
|
||||
|
@ -229,7 +229,7 @@ def watch(
|
||||
agent_learn: Optional[BasePolicy] = None,
|
||||
agent_opponent: Optional[BasePolicy] = None,
|
||||
) -> None:
|
||||
env = get_env()
|
||||
env = DummyVectorEnv([get_env])
|
||||
policy, optim, agents = get_agents(
|
||||
args, agent_learn=agent_learn, agent_opponent=agent_opponent
|
||||
)
|
||||
|
7
tianshou/env/pettingzoo_env.py
vendored
7
tianshou/env/pettingzoo_env.py
vendored
@ -32,17 +32,14 @@ class PettingZooEnv(AECEnv, ABC):
|
||||
self.agent_idx = {}
|
||||
for i, agent_id in enumerate(self.agents):
|
||||
self.agent_idx[agent_id] = i
|
||||
# Get dictionaries of obs_spaces and act_spaces
|
||||
self.observation_spaces = self.env.observation_spaces
|
||||
self.action_spaces = self.env.action_spaces
|
||||
|
||||
self.rewards = [0] * len(self.agents)
|
||||
|
||||
# Get first observation space, assuming all agents have equal space
|
||||
self.observation_space: Any = self.observation_space(self.agents[0])
|
||||
self.observation_space: Any = self.env.observation_space(self.agents[0])
|
||||
|
||||
# Get first action space, assuming all agents have equal space
|
||||
self.action_space: Any = self.action_space(self.agents[0])
|
||||
self.action_space: Any = self.env.action_space(self.agents[0])
|
||||
|
||||
assert all(self.env.observation_space(agent) == self.observation_space
|
||||
for agent in self.agents), \
|
||||
|
Loading…
x
Reference in New Issue
Block a user