Closes #914

Additional changes:

- Deprecate Python versions below 3.11
- Remove 3rd-party and throughput tests; this simplifies the install and test pipeline
- Remove gym compatibility and shimmy
- Format with 3.11 conventions; in particular, add `zip(..., strict=True/False)` where possible

Since the additional tests and gym were complicating the CI pipeline (flaky and dist-dependent), it didn't make sense to fix the current tests in this PR only to delete them in the next one. So this PR changes the build and removes these tests at the same time.
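For reference, a minimal sketch of what the `strict` flag changes (illustration only, not part of the diff): `zip(..., strict=True)` raises instead of silently truncating when the inputs have different lengths.

```python
# zip() silently truncates to the shortest input by default;
# strict=True (PEP 618) turns a length mismatch into a ValueError.
a = [1, 2, 3]
b = ["x", "y"]

print(list(zip(a, b)))  # [(1, 'x'), (2, 'y')] -- the mismatch goes unnoticed

try:
    list(zip(a, b, strict=True))
except ValueError as err:
    print(err)  # zip() argument 2 is shorter than argument 1
```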
# see issue #322 for detail
import copy
from collections import Counter

import gymnasium as gym
import numpy as np
from torch.utils.data import DataLoader, Dataset, DistributedSampler

from tianshou.data import Batch, Collector
from tianshou.env import BaseVectorEnv, DummyVectorEnv, SubprocVectorEnv
from tianshou.policy import BasePolicy


class DummyDataset(Dataset):
    def __init__(self, length):
        self.length = length
        self.episodes = [3 * i % 5 + 1 for i in range(self.length)]

    def __getitem__(self, index):
        assert 0 <= index < self.length
        return index, self.episodes[index]

    def __len__(self):
        return self.length


class FiniteEnv(gym.Env):
    def __init__(self, dataset, num_replicas, rank):
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.loader = DataLoader(
            dataset,
            sampler=DistributedSampler(dataset, num_replicas, rank),
            batch_size=None,
        )
        self.iterator = None

    def reset(self):
        if self.iterator is None:
            self.iterator = iter(self.loader)
        try:
            self.current_sample, self.step_count = next(self.iterator)
            self.current_step = 0
            return self.current_sample, {}
        except StopIteration:
            self.iterator = None
            return None, {}

    def step(self, action):
        self.current_step += 1
        assert self.current_step <= self.step_count
        return (
            0,
            1.0,
            self.current_step >= self.step_count,
            False,
            {"sample": self.current_sample, "action": action, "metric": 2.0},
        )


class FiniteVectorEnv(BaseVectorEnv):
    """Vector env wrapper that tracks which sub-envs still have data and raises
    StopIteration from reset() once every sub-env is exhausted."""

    def __init__(self, env_fns, **kwargs):
        super().__init__(env_fns, **kwargs)
        self._alive_env_ids = set()
        self._reset_alive_envs()
        self._default_obs = self._default_info = None

    def _reset_alive_envs(self):
        if not self._alive_env_ids:
            # starting or running out
            self._alive_env_ids = set(range(self.env_num))

    # helpers to work around tianshou's buffer and Batch handling:
    # exhausted envs are padded with copies of a previously seen obs/info
    def _set_default_obs(self, obs):
        if obs is not None and self._default_obs is None:
            self._default_obs = copy.deepcopy(obs)

    def _set_default_info(self, info):
        if info is not None and self._default_info is None:
            self._default_info = copy.deepcopy(info)

    def _get_default_obs(self):
        return copy.deepcopy(self._default_obs)

    def _get_default_info(self):
        return copy.deepcopy(self._default_info)

    # end of workaround helpers

    def reset(self, id=None):
        id = self._wrap_id(id)
        self._reset_alive_envs()

        # ask super to reset alive envs and remap to current index
        request_id = list(filter(lambda i: i in self._alive_env_ids, id))
        obs = [None] * len(id)
        infos = [None] * len(id)
        id2idx = {i: k for k, i in enumerate(id)}
        if request_id:
            for k, o, info in zip(request_id, *super().reset(request_id), strict=True):
                obs[id2idx[k]] = o
                infos[id2idx[k]] = info
        for i, o in zip(id, obs, strict=True):
            if o is None and i in self._alive_env_ids:
                self._alive_env_ids.remove(i)

        # fill empty observation with default (fake) observation
        for o in obs:
            self._set_default_obs(o)

        for i in range(len(obs)):
            if obs[i] is None:
                obs[i] = self._get_default_obs()
            if infos[i] is None:
                infos[i] = self._get_default_info()

        if not self._alive_env_ids:
            self.reset()
            raise StopIteration

        return np.stack(obs), infos

    def step(self, action, id=None):
        id = self._wrap_id(id)
        id2idx = {i: k for k, i in enumerate(id)}
        request_id = list(filter(lambda i: i in self._alive_env_ids, id))
        result = [[None, 0.0, False, False, None] for _ in range(len(id))]

        # ask super to step alive envs and remap to current index
        if request_id:
            valid_act = np.stack([action[id2idx[i]] for i in request_id])
            for i, r in zip(
                request_id,
                zip(*super().step(valid_act, request_id), strict=True),
                strict=True,
            ):
                result[id2idx[i]] = r

        # logging
        for i, r in zip(id, result, strict=True):
            if i in self._alive_env_ids:
                self.tracker.log(*r)

        # fill empty observation/info with default (fake) values
        for _, __, ___, ____, i in result:
            self._set_default_info(i)
        for i in range(len(result)):
            if result[i][0] is None:
                result[i][0] = self._get_default_obs()
            if result[i][-1] is None:
                result[i][-1] = self._get_default_info()

        return list(map(np.stack, zip(*result, strict=True)))


class FiniteDummyVectorEnv(FiniteVectorEnv, DummyVectorEnv):
    pass


class FiniteSubprocVectorEnv(FiniteVectorEnv, SubprocVectorEnv):
    pass


class AnyPolicy(BasePolicy):
    def forward(self, batch, state=None):
        return Batch(act=np.stack([1] * len(batch)))

    def learn(self, batch):
        pass


def _finite_env_factory(dataset, num_replicas, rank):
    return lambda: FiniteEnv(dataset, num_replicas, rank)


class MetricTracker:
    def __init__(self):
        self.counter = Counter()
        self.finished = set()

    def log(self, obs, rew, terminated, truncated, info):
        assert rew == 1.0
        done = terminated or truncated
        index = info["sample"]
        if done:
            assert index not in self.finished
            self.finished.add(index)
        self.counter[index] += 1

    def validate(self):
        assert len(self.finished) == 100
        for k, v in self.counter.items():
            assert v == k * 3 % 5 + 1


def test_finite_dummy_vector_env():
    dataset = DummyDataset(100)
    # use the dummy (in-process) variant here; the subproc variant is covered below
    envs = FiniteDummyVectorEnv([_finite_env_factory(dataset, 5, i) for i in range(5)])
    policy = AnyPolicy()
    test_collector = Collector(policy, envs, exploration_noise=True)

    for _ in range(3):
        envs.tracker = MetricTracker()
        try:
            test_collector.collect(n_step=10**18)
        except StopIteration:
            envs.tracker.validate()


def test_finite_subproc_vector_env():
    dataset = DummyDataset(100)
    envs = FiniteSubprocVectorEnv([_finite_env_factory(dataset, 5, i) for i in range(5)])
    policy = AnyPolicy()
    test_collector = Collector(policy, envs, exploration_noise=True)

    for _ in range(3):
        envs.tracker = MetricTracker()
        try:
            test_collector.collect(n_step=10**18)
        except StopIteration:
            envs.tracker.validate()


if __name__ == "__main__":
    test_finite_dummy_vector_env()
    test_finite_subproc_vector_env()