Closes #947

This removes all kwargs from all policy constructors. While doing that, I also improved several names and added a whole lot of TODOs.

## Functional changes:

1. Added the possibility to pass `None` as `critic2` and `critic2_optim`. The default behavior in that case should cover the vast majority of use cases.
2. Added a function called `clone_optimizer` as a temporary measure to support passing `critic2_optim=None`.

## Breaking changes:

1. `action_space` is no longer optional. In fact, it already was non-optional, as there was already a `ValueError` in `BasePolicy.__init__`. Several examples were fixed to reflect that.
2. `reward_normalization` was removed from DDPG and its children. Passing it as `True` was never allowed there: an error would have been raised in `compute_n_step_reward`. It is now removed from the interface.
3. Renamed `critic1` and similar to `critic`, in order to have uniform interfaces. Note that `critic` in DDPG was optional for the sole reason that child classes used `critic1`. I removed this optionality (DDPG can't do anything with `critic=None`).
4. Several fields were renamed (mostly from private to public, so this is backwards compatible).

## Additional changes:

1. Removed type and default declarations from docstrings. This kind of duplication is really not necessary.
2. Policy constructors are now only called using named arguments, not a fragile mixture of positional and named arguments as before.
3. Minor beautifications in typing and code.
4. Generally shortened docstrings and made them uniform across all policies (hopefully).

## Comment:

With these changes, several problems in tianshou's inheritance hierarchy become more apparent. I tried highlighting them for future work.

---------

Co-authored-by: Dominik Jain <d.jain@appliedai.de>
		
			
				
	
	
		
			233 lines
		
	
	
		
			6.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			233 lines
		
	
	
		
			6.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # see issue #322 for details
 | |
| 
 | |
| import copy
 | |
| from collections import Counter
 | |
| 
 | |
| import gymnasium as gym
 | |
| import numpy as np
 | |
| from gymnasium.spaces import Box
 | |
| from torch.utils.data import DataLoader, Dataset, DistributedSampler
 | |
| 
 | |
| from tianshou.data import Batch, Collector
 | |
| from tianshou.env import BaseVectorEnv, DummyVectorEnv, SubprocVectorEnv
 | |
| from tianshou.policy import BasePolicy
 | |
| 
 | |
| 
 | |
class DummyDataset(Dataset):
    """Map-style dataset whose item ``i`` is the pair ``(i, episode_length)``.

    The episode length of sample ``i`` is ``3 * i % 5 + 1``, i.e. a value in
    the range 1..5.
    """

    def __init__(self, length):
        self.length = length
        # precomputed episode length for every sample index
        self.episodes = [1 + (3 * idx) % 5 for idx in range(length)]

    def __getitem__(self, index):
        # only non-negative, in-range indices are valid (no Python-style wrap-around)
        assert index >= 0 and index < self.length
        episode_len = self.episodes[index]
        return index, episode_len

    def __len__(self):
        return self.length
 | |
| 
 | |
| 
 | |
class FiniteEnv(gym.Env):
    """Env that plays through one replica's share of a dataset exactly once per pass.

    Each ``reset`` draws the next ``(sample, step_count)`` item from a
    ``DataLoader`` whose ``DistributedSampler`` selects this replica's share;
    the episode then lasts exactly ``step_count`` steps. When the share is
    exhausted, ``reset`` returns ``None`` as the observation and rewinds, so the
    following ``reset`` starts a new pass over the same share.
    """

    def __init__(self, dataset, num_replicas, rank):
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.loader = DataLoader(
            dataset,
            sampler=DistributedSampler(dataset, num_replicas, rank),
            batch_size=None,  # yield individual items rather than collated batches
        )
        # created lazily on the first reset of each pass
        self.iterator = None

    def reset(self, *, seed=None, options=None):
        """Start the next episode.

        Accepts gymnasium's keyword-only ``seed`` / ``options`` so callers using
        the standard ``Env.reset`` signature work. ``seed`` is forwarded to
        ``gym.Env.reset``; ``options`` is ignored.

        Returns:
            ``(sample, {})`` for the next sample, or ``(None, {})`` once this
            replica's share of the dataset is exhausted.
        """
        super().reset(seed=seed)
        if self.iterator is None:
            self.iterator = iter(self.loader)
        try:
            self.current_sample, self.step_count = next(self.iterator)
        except StopIteration:
            # exhausted: drop the iterator so the next reset begins a new pass
            self.iterator = None
            return None, {}
        self.current_step = 0
        return self.current_sample, {}

    def step(self, action):
        """Advance one step; the episode terminates after ``step_count`` steps.

        Always returns observation ``0`` and reward ``1.0``, never truncates, and
        echoes the current sample and the given action in ``info``.
        """
        self.current_step += 1
        assert self.current_step <= self.step_count
        return (
            0,
            1.0,
            self.current_step >= self.step_count,
            False,
            {"sample": self.current_sample, "action": action, "metric": 2.0},
        )
 | |
| 
 | |
| 
 | |
class FiniteVectorEnv(BaseVectorEnv):
    """Vector env over sub-envs that run out of data (see issue #322).

    A sub-env signals exhaustion by returning ``None`` as its observation from
    ``reset``. Such envs are removed from the set of "alive" envs. Their slots
    in subsequent outputs are filled with deep copies of a default (fake)
    observation/info, so stacked outputs keep one entry per requested id.
    Once no env is alive, all envs are reset and ``StopIteration`` is raised.

    ``step`` reports every alive env's transition to ``self.tracker.log``.
    ``tracker`` is not created here; the caller must assign it before stepping.
    """

    def __init__(self, env_fns, **kwargs):
        super().__init__(env_fns, **kwargs)
        self._alive_env_ids = set()
        self._reset_alive_envs()
        # placeholders for exhausted envs, captured from the first real obs/info seen
        self._default_obs = self._default_info = None

    def _reset_alive_envs(self):
        """Mark all envs alive again, but only when none is currently alive."""
        if not self._alive_env_ids:
            # starting or running out
            self._alive_env_ids = set(range(self.env_num))

    # to workaround with tianshou's buffer and batch
    def _set_default_obs(self, obs):
        """Remember a copy of the first non-None observation as the placeholder."""
        if obs is not None and self._default_obs is None:
            self._default_obs = copy.deepcopy(obs)

    def _set_default_info(self, info):
        """Remember a copy of the first non-None info as the placeholder."""
        if info is not None and self._default_info is None:
            self._default_info = copy.deepcopy(info)

    def _get_default_obs(self):
        """Return a fresh copy of the placeholder observation (None if unset)."""
        return copy.deepcopy(self._default_obs)

    def _get_default_info(self):
        """Return a fresh copy of the placeholder info (None if unset)."""
        return copy.deepcopy(self._default_info)

    # END

    def reset(self, id=None):
        """Reset the alive envs among ``id`` (all envs when ``id`` is None).

        Envs whose reset yields a ``None`` observation are marked exhausted.
        Slots without a real observation/info receive copies of the placeholders.
        The default info is only captured in ``step``, so before the first step
        it is still None.

        Returns:
            ``(np.stack(obs), infos)`` with one entry per id in ``id``.

        Raises:
            StopIteration: when no env is alive any more; all envs are reset
                first so that a new pass can begin.
        """
        id = self._wrap_id(id)
        self._reset_alive_envs()

        # ask super to reset alive envs and remap to current index
        request_id = [i for i in id if i in self._alive_env_ids]
        obs = [None] * len(id)
        infos = [None] * len(id)
        id2idx = {i: k for k, i in enumerate(id)}
        if request_id:
            for k, o, info in zip(request_id, *super().reset(request_id), strict=True):
                obs[id2idx[k]] = o
                infos[id2idx[k]] = info
        # an alive env that produced no observation has run out of data
        for i, o in zip(id, obs, strict=True):
            if o is None and i in self._alive_env_ids:
                self._alive_env_ids.remove(i)

        # fill empty observation with default(fake) observation
        for o in obs:
            self._set_default_obs(o)

        for k, o in enumerate(obs):
            if o is None:
                obs[k] = self._get_default_obs()
            if infos[k] is None:
                infos[k] = self._get_default_info()

        if not self._alive_env_ids:
            # every env is exhausted: rewind all of them, then end the pass
            self.reset()
            raise StopIteration

        return np.stack(obs), infos

    def step(self, action, id=None):
        """Step the alive envs among ``id``; others get placeholder transitions.

        ``action`` is indexed by position within ``id``. Every alive env's
        ``(obs, rew, terminated, truncated, info)`` is passed to
        ``self.tracker.log``. Non-alive envs get reward ``0.0``, both flags
        ``False``, and copies of the placeholder obs/info.

        Returns:
            A list of five stacked arrays: obs, rew, terminated, truncated, info.
        """
        id = self._wrap_id(id)
        id2idx = {i: k for k, i in enumerate(id)}
        request_id = [i for i in id if i in self._alive_env_ids]
        # placeholder transition per requested id (overwritten for alive envs)
        result = [[None, 0.0, False, False, None] for _ in range(len(id))]

        # ask super to step alive envs and remap to current index
        if request_id:
            valid_act = np.stack([action[id2idx[i]] for i in request_id])
            for i, r in zip(
                request_id,
                zip(*super().step(valid_act, request_id), strict=True),
                strict=True,
            ):
                # store as a list: the fill-in below may assign into it, and
                # zip() produces immutable tuples
                result[id2idx[i]] = list(r)

        # logging
        for i, r in zip(id, result, strict=True):
            if i in self._alive_env_ids:
                self.tracker.log(*r)

        # fill empty observation/info with default(fake)
        for *_, info in result:
            self._set_default_info(info)
        for r in result:
            if r[0] is None:
                r[0] = self._get_default_obs()
            if r[-1] is None:
                r[-1] = self._get_default_info()

        return list(map(np.stack, zip(*result, strict=True)))
 | |
| 
 | |
| 
 | |
class FiniteDummyVectorEnv(FiniteVectorEnv, DummyVectorEnv):
    """``FiniteVectorEnv`` bookkeeping combined with ``DummyVectorEnv``.

    ``FiniteVectorEnv`` is listed first in the bases, so its ``reset``/``step``
    take precedence in the MRO over the ``DummyVectorEnv`` versions.
    """
 | |
| 
 | |
| 
 | |
class FiniteSubprocVectorEnv(FiniteVectorEnv, SubprocVectorEnv):
    """``FiniteVectorEnv`` bookkeeping combined with ``SubprocVectorEnv``.

    ``FiniteVectorEnv`` is listed first in the bases, so its ``reset``/``step``
    take precedence in the MRO over the ``SubprocVectorEnv`` versions.
    """
 | |
| 
 | |
| 
 | |
class AnyPolicy(BasePolicy):
    """Trivial test policy: always acts with ``1`` and never learns."""

    def __init__(self):
        action_space = Box(-1, 1, (1,))
        super().__init__(action_space=action_space)

    def forward(self, batch, state=None):
        # one constant action per entry of the incoming batch
        constant_actions = len(batch) * [1]
        return Batch(act=np.stack(constant_actions))

    def learn(self, batch):
        return None
 | |
| 
 | |
| 
 | |
| def _finite_env_factory(dataset, num_replicas, rank):
 | |
|     return lambda: FiniteEnv(dataset, num_replicas, rank)
 | |
| 
 | |
| 
 | |
class MetricTracker:
    """Records how many steps each dataset sample took and which samples finished.

    ``log`` takes one unpacked ``(obs, rew, terminated, truncated, info)``
    transition. ``validate`` checks that 100 samples finished and that sample
    ``k`` was stepped ``3 * k % 5 + 1`` times.
    """

    def __init__(self):
        self.finished = set()
        self.counter = Counter()

    def log(self, obs, rew, terminated, truncated, info):
        assert rew == 1.0
        episode_over = terminated or truncated
        sample = info["sample"]
        if episode_over:
            # a sample must never finish twice within one pass
            assert sample not in self.finished
            self.finished.add(sample)
        self.counter[sample] += 1

    def validate(self):
        assert len(self.finished) == 100
        for sample, hits in self.counter.items():
            expected_hits = 3 * sample % 5 + 1
            assert hits == expected_hits
 | |
| 
 | |
| 
 | |
def test_finite_dummy_vector_env():
    """Run three passes over a finite dataset and validate per-sample step counts.

    NOTE(review): despite its name, this test builds a ``FiniteSubprocVectorEnv``,
    exactly like ``test_finite_subproc_vector_env``; ``FiniteDummyVectorEnv`` is
    never exercised. Switching it over presumably also requires ``FiniteEnv`` to
    expose ``action_space`` for in-process workers -- verify before changing.
    """
    dataset = DummyDataset(100)
    envs = FiniteSubprocVectorEnv([_finite_env_factory(dataset, 5, i) for i in range(5)])
    try:
        policy = AnyPolicy()
        test_collector = Collector(policy, envs, exploration_noise=True)

        for _ in range(3):
            envs.tracker = MetricTracker()
            try:
                # collect "forever": the pass ends when the env raises StopIteration
                test_collector.collect(n_step=10**18)
            except StopIteration:
                envs.tracker.validate()
            else:
                raise AssertionError("collect() returned without the env signalling exhaustion")
    finally:
        # release the worker processes even if the test fails
        envs.close()
 | |
| 
 | |
| 
 | |
def test_finite_subproc_vector_env():
    """Run three passes of a ``FiniteSubprocVectorEnv`` and validate per-sample step counts."""
    dataset = DummyDataset(100)
    envs = FiniteSubprocVectorEnv([_finite_env_factory(dataset, 5, i) for i in range(5)])
    try:
        policy = AnyPolicy()
        test_collector = Collector(policy, envs, exploration_noise=True)

        for _ in range(3):
            envs.tracker = MetricTracker()
            try:
                # collect "forever": the pass ends when the env raises StopIteration
                test_collector.collect(n_step=10**18)
            except StopIteration:
                envs.tracker.validate()
            else:
                raise AssertionError("collect() returned without the env signalling exhaustion")
    finally:
        # release the worker processes even if the test fails
        envs.close()
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # run both checks in order when executed as a script
    for run_check in (test_finite_dummy_vector_env, test_finite_subproc_vector_env):
        run_check()
 |