Closes #947 This removes all kwargs from all policy constructors. While doing that, I also improved several names and added a whole lot of TODOs. ## Functional changes: 1. Added the possibility to pass None as `critic2` and `critic2_optim`. In fact, the default behavior then should cover the absolute majority of cases. 2. Added a function called `clone_optimizer` as a temporary measure to support passing `critic2_optim=None`. ## Breaking changes: 1. `action_space` is no longer optional. In fact, it already was non-optional, as there was a ValueError in `BasePolicy.__init__`. So now several examples were fixed to reflect that. 2. `reward_normalization` removed from DDPG and children. It was never allowed to pass it as `True` there; an error would have been raised in `compute_n_step_reward`. Now I removed it from the interface. 3. Renamed `critic1` and similar to `critic`, in order to have uniform interfaces. Note that the `critic` in DDPG was optional for the sole reason that child classes used `critic1`. I removed this optionality (DDPG can't do anything with `critic=None`). 4. Several renamings of fields (mostly private to public, so backwards compatible). ## Additional changes: 1. Removed type and default declaration from docstring. This kind of duplication is really not necessary. 2. Policy constructors are now only called using named arguments, not a fragile mixture of positional and named as before. 3. Minor beautifications in typing and code. 4. Generally shortened docstrings and made them uniform across all policies (hopefully). ## Comment: With these changes, several problems in tianshou's inheritance hierarchy become more apparent. I tried highlighting them for future work. --------- Co-authored-by: Dominik Jain <d.jain@appliedai.de>
		
			
				
	
	
		
			233 lines
		
	
	
		
			6.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			233 lines
		
	
	
		
			6.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# see issue #322 for detail
 | 
						|
 | 
						|
import copy
 | 
						|
from collections import Counter
 | 
						|
 | 
						|
import gymnasium as gym
 | 
						|
import numpy as np
 | 
						|
from gymnasium.spaces import Box
 | 
						|
from torch.utils.data import DataLoader, Dataset, DistributedSampler
 | 
						|
 | 
						|
from tianshou.data import Batch, Collector
 | 
						|
from tianshou.env import BaseVectorEnv, DummyVectorEnv, SubprocVectorEnv
 | 
						|
from tianshou.policy import BasePolicy
 | 
						|
 | 
						|
 | 
						|
class DummyDataset(Dataset):
    """Synthetic dataset pairing each index with an episode length.

    Item ``i`` is the tuple ``(i, 3 * i % 5 + 1)``, so the episode lengths
    take values between 1 and 5.
    """

    def __init__(self, length):
        self.length = length
        # Pre-compute the episode length attached to every sample index.
        self.episodes = [(3 * idx) % 5 + 1 for idx in range(length)]

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        assert 0 <= index < self.length
        episode_len = self.episodes[index]
        return index, episode_len
 | 
						|
 | 
						|
 | 
						|
class FiniteEnv(gym.Env):
    """Environment that replays its share of a dataset and then runs dry.

    Samples are drawn through a ``DataLoader`` driven by a
    ``DistributedSampler`` built from ``num_replicas`` and ``rank``.
    Each sample is unpacked into ``(current_sample, step_count)``; an episode
    lasts ``step_count`` steps and every step yields a reward of ``1.0``.
    """

    def __init__(self, dataset, num_replicas, rank):
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        shard_sampler = DistributedSampler(dataset, num_replicas, rank)
        self.loader = DataLoader(dataset, sampler=shard_sampler, batch_size=None)
        # Created lazily on the first reset and dropped again once exhausted.
        self.iterator = None

    def reset(self):
        """Start the next episode.

        :return: ``(sample, {})`` for the next sample, or ``(None, {})`` once
            the loader is exhausted. In the latter case the iterator is
            cleared so that the following reset starts a fresh pass.
        """
        if self.iterator is None:
            self.iterator = iter(self.loader)
        try:
            next_item = next(self.iterator)
        except StopIteration:
            self.iterator = None
            return None, {}
        self.current_sample, self.step_count = next_item
        self.current_step = 0
        return self.current_sample, {}

    def step(self, action):
        """Advance the current episode by one step.

        :return: ``(obs, rew, terminated, truncated, info)`` where the
            observation is always ``0``, the reward always ``1.0``, and
            ``terminated`` becomes true once ``step_count`` steps were taken.
        """
        self.current_step += 1
        assert self.current_step <= self.step_count
        terminated = self.current_step >= self.step_count
        info = {"sample": self.current_sample, "action": action, "metric": 2.0}
        return 0, 1.0, terminated, False, info
 | 
						|
 | 
						|
 | 
						|
class FiniteVectorEnv(BaseVectorEnv):
    """Vectorized env whose workers may run out of data.

    A worker signals exhaustion by returning ``None`` as its observation from
    ``reset``. Such workers are removed from the set of "alive" env ids and
    are no longer reset or stepped through the parent class; their slots in
    the returned batch are filled with copies of a cached default
    observation/info instead. Once no worker is alive, ``reset`` raises
    ``StopIteration``.

    NOTE: ``step`` calls ``self.tracker.log(...)``, but ``tracker`` is not
    assigned anywhere in this class; it has to be attached to the instance by
    the caller before stepping.
    """

    def __init__(self, env_fns, **kwargs):
        """Build the underlying vector env and mark every worker as alive.

        :param env_fns: forwarded unchanged to the parent constructor.
        :param kwargs: forwarded unchanged to the parent constructor.
        """
        super().__init__(env_fns, **kwargs)
        self._alive_env_ids = set()
        self._reset_alive_envs()
        # Cached placeholders, captured lazily from the first real obs/info.
        self._default_obs = self._default_info = None

    def _reset_alive_envs(self):
        """Mark all env ids as alive again, but only if none is alive."""
        if not self._alive_env_ids:
            # starting or running out
            self._alive_env_ids = set(range(self.env_num))

    # Helpers for placeholder obs/info: a workaround for tianshou's buffer
    # and batch, which need an entry for every requested env id.
    def _set_default_obs(self, obs):
        """Remember a deep copy of ``obs`` as default, first non-None only."""
        if obs is not None and self._default_obs is None:
            self._default_obs = copy.deepcopy(obs)

    def _set_default_info(self, info):
        """Remember a deep copy of ``info`` as default, first non-None only."""
        if info is not None and self._default_info is None:
            self._default_info = copy.deepcopy(info)

    def _get_default_obs(self):
        """Return a fresh deep copy of the cached default observation."""
        return copy.deepcopy(self._default_obs)

    def _get_default_info(self):
        """Return a fresh deep copy of the cached default info."""
        return copy.deepcopy(self._default_info)

    # END of placeholder helpers

    def reset(self, id=None):
        """Reset the requested envs, skipping those that are exhausted.

        :param id: env ids to reset; passed through ``self._wrap_id``
            (inherited; presumably normalizes it to a list of ids - confirm
            against ``BaseVectorEnv``).
        :return: ``(np.stack(obs), infos)`` with one entry per requested id.
        :raises StopIteration: if no env is alive after this reset. Before
            raising, ``self.reset()`` is called once more for all envs.
        """
        id = self._wrap_id(id)
        self._reset_alive_envs()

        # ask super to reset alive envs and remap to current index
        request_id = list(filter(lambda i: i in self._alive_env_ids, id))
        obs = [None] * len(id)
        infos = [None] * len(id)
        # Maps an env id to its position within the requested ``id`` list.
        id2idx = {i: k for k, i in enumerate(id)}
        if request_id:
            for k, o, info in zip(request_id, *super().reset(request_id), strict=True):
                obs[id2idx[k]] = o
                infos[id2idx[k]] = info
        # An alive env that produced no observation has run out of data.
        for i, o in zip(id, obs, strict=True):
            if o is None and i in self._alive_env_ids:
                self._alive_env_ids.remove(i)

        # fill empty observation with default(fake) observation
        for o in obs:
            self._set_default_obs(o)

        # NOTE: the default info is only ever captured in ``step``; here it
        # is merely read, so it may still be None at this point.
        for i in range(len(obs)):
            if obs[i] is None:
                obs[i] = self._get_default_obs()
            if infos[i] is None:
                infos[i] = self._get_default_info()

        if not self._alive_env_ids:
            # Everything is exhausted: reset all envs again (the alive set is
            # refilled inside that call), then signal the end to the caller.
            self.reset()
            raise StopIteration

        return np.stack(obs), infos

    def step(self, action, id=None):
        """Step the alive envs among ``id`` and pad the rest with defaults.

        :param action: actions indexed by position within ``id``.
        :param id: env ids to step; passed through ``self._wrap_id``.
        :return: a list of five stacked arrays built from the per-env
            ``(obs, rew, terminated, truncated, info)`` entries.
        """
        id = self._wrap_id(id)
        id2idx = {i: k for k, i in enumerate(id)}
        request_id = list(filter(lambda i: i in self._alive_env_ids, id))
        # Placeholder entry for every requested env; alive ones are replaced.
        result = [[None, 0.0, False, False, None] for _ in range(len(id))]

        # ask super to step alive envs and remap to current index
        if request_id:
            valid_act = np.stack([action[id2idx[i]] for i in request_id])
            for i, r in zip(
                request_id,
                # Transpose the parent's per-field arrays into per-env tuples.
                zip(*super().step(valid_act, request_id), strict=True),
                strict=True,
            ):
                result[id2idx[i]] = r

        # logging
        for i, r in zip(id, result, strict=True):
            if i in self._alive_env_ids:
                self.tracker.log(*r)

        # fill empty observation/info with default(fake)
        for _, __, ___, ____, i in result:
            self._set_default_info(i)
        for i in range(len(result)):
            if result[i][0] is None:
                result[i][0] = self._get_default_obs()
            if result[i][-1] is None:
                result[i][-1] = self._get_default_info()

        return list(map(np.stack, zip(*result, strict=True)))
 | 
						|
 | 
						|
 | 
						|
class FiniteDummyVectorEnv(FiniteVectorEnv, DummyVectorEnv):
    """Combine ``FiniteVectorEnv`` with ``DummyVectorEnv``; adds no members."""
 | 
						|
 | 
						|
 | 
						|
class FiniteSubprocVectorEnv(FiniteVectorEnv, SubprocVectorEnv):
    """Combine ``FiniteVectorEnv`` with ``SubprocVectorEnv``; adds no members."""
 | 
						|
 | 
						|
 | 
						|
class AnyPolicy(BasePolicy):
    """Trivial policy that always outputs the action ``1`` and never learns."""

    def __init__(self):
        unit_box = Box(-1, 1, (1,))
        super().__init__(action_space=unit_box)

    def forward(self, batch, state=None):
        """Return a ``Batch`` whose ``act`` holds one ``1`` per batch entry."""
        constant_actions = np.stack([1 for _ in range(len(batch))])
        return Batch(act=constant_actions)

    def learn(self, batch):
        """Do nothing; this policy has nothing to train."""
        return None
 | 
						|
 | 
						|
 | 
						|
def _finite_env_factory(dataset, num_replicas, rank):
 | 
						|
    return lambda: FiniteEnv(dataset, num_replicas, rank)
 | 
						|
 | 
						|
 | 
						|
class MetricTracker:
    """Record per-sample step counts and check them against expectations.

    ``counter`` maps a sample index to the number of logged steps, and
    ``finished`` holds the indices whose episode has ended.
    """

    def __init__(self):
        self.counter = Counter()
        self.finished = set()

    def log(self, obs, rew, terminated, truncated, info):
        """Register one environment step.

        Asserts that the reward is ``1.0`` and that a sample does not finish
        more than once.
        """
        assert rew == 1.0
        episode_over = terminated or truncated
        sample_idx = info["sample"]
        if episode_over:
            assert sample_idx not in self.finished
            self.finished.add(sample_idx)
        self.counter[sample_idx] += 1

    def validate(self):
        """Assert that 100 samples finished, each with its expected length."""
        assert len(self.finished) == 100
        for sample_idx, n_steps in self.counter.items():
            expected_steps = sample_idx * 3 % 5 + 1
            assert n_steps == expected_steps
 | 
						|
 | 
						|
 | 
						|
def test_finite_dummy_vector_env():
    """Exhaust a 100-sample dataset through ``FiniteDummyVectorEnv``.

    Five environments share the dataset. The collector is asked for a
    practically unbounded number of steps, so each pass is expected to end
    with the ``StopIteration`` raised by the vector env once all workers ran
    dry; the tracker is validated at that point. The pass is repeated three
    times on the same envs.
    """
    dataset = DummyDataset(100)
    # Use the dummy (in-process) variant here: the original built
    # FiniteSubprocVectorEnv, which left FiniteDummyVectorEnv untested.
    envs = FiniteDummyVectorEnv([_finite_env_factory(dataset, 5, i) for i in range(5)])
    policy = AnyPolicy()
    test_collector = Collector(policy, envs, exploration_noise=True)

    for _ in range(3):
        # Fresh tracker per pass; FiniteVectorEnv.step logs into envs.tracker.
        envs.tracker = MetricTracker()
        try:
            test_collector.collect(n_step=10**18)
        except StopIteration:
            envs.tracker.validate()
 | 
						|
 | 
						|
 | 
						|
def test_finite_subproc_vector_env():
    """Exhaust a 100-sample dataset through ``FiniteSubprocVectorEnv``.

    Runs three passes; each pass collects until the vector env raises
    ``StopIteration`` and then validates the tracker.
    """
    num_envs = 5
    data = DummyDataset(100)
    env_fns = [_finite_env_factory(data, num_envs, rank) for rank in range(num_envs)]
    vec_env = FiniteSubprocVectorEnv(env_fns)
    collector = Collector(AnyPolicy(), vec_env, exploration_noise=True)

    for _ in range(3):
        tracker = MetricTracker()
        vec_env.tracker = tracker
        try:
            collector.collect(n_step=10**18)
        except StopIteration:
            vec_env.tracker.validate()
 | 
						|
 | 
						|
 | 
						|
if __name__ == "__main__":
    # Run both finite vector env tests, in the same order as before.
    for _run_test in (test_finite_dummy_vector_env, test_finite_subproc_vector_env):
        _run_test()
 |