From 73600edc582b31feeaed2db8a7c42e2ff06dfc44 Mon Sep 17 00:00:00 2001 From: Oren Zeev-Ben-Mordehai Date: Mon, 13 Mar 2023 01:58:09 +0100 Subject: [PATCH] fix a bug in batch._is_batch_set (#825) - [ ] I have marked all applicable categories: + [x] exception-raising fix + [ ] algorithm implementation fix + [ ] documentation modification + [ ] new feature - [ ] I have reformatted the code using `make format` (**required**) - [ ] I have checked the code using `make commit-checks` (**required**) - [ ] If applicable, I have mentioned the relevant/related issue(s) - [ ] If applicable, I have listed every items in this Pull Request below I'm developing a new PettingZoo environment. It is a two players turns board game. ``` obs_space = dict( board = gym.spaces.MultiBinary([8, 8]), player = gym.spaces.Tuple([gym.spaces.Discrete(8)] * 2), other_player = gym.spaces.Tuple([gym.spaces.Discrete(8)] * 2) ) self._observation_space = gym.spaces.Dict(spaces=obs_space) self._action_space = gym.spaces.Tuple([gym.spaces.Discrete(8)] * 2) ... # this cache ensures that same space object is returned for the same agent # allows action space seeding to work as expected @functools.lru_cache(maxsize=None) def observation_space(self, agent): # gymnasium spaces are defined and documented here: https://gymnasium.farama.org/api/spaces/ return self._observation_space @functools.lru_cache(maxsize=None) def action_space(self, agent): return self._action_space ``` My test is: ``` def test_with_tianshou(): action = None # env = gym.make('qwertyenv/CollectCoins-v0', pieces=['rock', 'rock']) env = CollectCoinsEnv(pieces=['rock', 'rock'], with_mask=True) def another_action_taken(action_taken): nonlocal action action = action_taken # Wrapping the original environment as to make sure a valid action will be taken. env = EnsureValidAction( env, env.check_action_valid, env.provide_alternative_valid_action, another_action_taken ) env = PettingZooEnv(env) policies = MultiAgentPolicyManager([RandomPolicy(), RandomPolicy()], env) env = DummyVectorEnv([lambda: env]) collector = Collector(policies, env) result = collector.collect(n_step=200, render=0.1) ``` I have also a wrapper that may be redundant as of Tianshou capability to action_mask, yet it is still part of the code: ``` from typing import TypeVar, Callable import gymnasium as gym from pettingzoo.utils.wrappers import BaseWrapper Action = TypeVar("Action") class ActionWrapper(BaseWrapper): def __init__(self, env: gym.Env): super().__init__(env) def step(self, action): action = self.action(action) self.env.step(action) def action(self, action): pass def render(self, *args, **kwargs): self.env.render(*args, **kwargs) class EnsureValidAction(ActionWrapper): """ A gym environment wrapper to help with the case that the agent wants to take invalid actions. For example consider a Chess game, where you let the action_space be any piece moving to any square on the board, but then when a wrong move is taken, instead of returing a big negative reward, you just take another action, this time a valid one. To make sure the learning algorithm is aware of the action taken, a callback should be provided. """ def __init__(self, env: gym.Env, check_action_valid: Callable[[Action], bool], provide_alternative_valid_action: Callable[[Action], Action], alternative_action_cb: Callable[[Action], None]): super().__init__(env) self.check_action_valid = check_action_valid self.provide_alternative_valid_action = provide_alternative_valid_action self.alternative_action_cb = alternative_action_cb def action(self, action: Action) -> Action: if self.check_action_valid(action): return action alternative_action = self.provide_alternative_valid_action(action) self.alternative_action_cb(alternative_action) return alternative_action ``` To make above work I had to patch a bit PettingZoo (opened a pull-request there), and a small patch here (this PR). Maybe I'm doing something wrong, yet I fail to see it. With my both fixes of PZ and of Tianshou, I have two tests, one of the environment by itself, and the other as of above. --- tianshou/data/batch.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 7fddda2..e204151 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -19,6 +19,8 @@ def _is_batch_set(obj: Any) -> bool: # "for element in obj" will just unpack the first dimension, # but obj.tolist() will flatten ndarray of objects # so do not use obj.tolist() + if obj.shape == (): + return False return obj.dtype == object and \ all(isinstance(element, (dict, Batch)) for element in obj) elif isinstance(obj, (list, tuple)):