parent
3ac67d9974
commit
7db21f3df6
212
test/base/test_env_finite.py
Normal file
212
test/base/test_env_finite.py
Normal file
@ -0,0 +1,212 @@
|
||||
# see issue #322 for detail
|
||||
|
||||
import gym
|
||||
import copy
|
||||
import numpy as np
|
||||
from collections import Counter
|
||||
from torch.utils.data import Dataset, DataLoader, DistributedSampler
|
||||
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.data import Collector, Batch
|
||||
from tianshou.env import BaseVectorEnv, DummyVectorEnv, SubprocVectorEnv
|
||||
|
||||
|
||||
class DummyDataset(Dataset):
|
||||
def __init__(self, length):
|
||||
self.length = length
|
||||
self.episodes = [3 * i % 5 + 1 for i in range(self.length)]
|
||||
|
||||
def __getitem__(self, index):
|
||||
assert 0 <= index < self.length
|
||||
return index, self.episodes[index]
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
|
||||
class FiniteEnv(gym.Env):
|
||||
def __init__(self, dataset, num_replicas, rank):
|
||||
self.dataset = dataset
|
||||
self.num_replicas = num_replicas
|
||||
self.rank = rank
|
||||
self.loader = DataLoader(
|
||||
dataset,
|
||||
sampler=DistributedSampler(dataset, num_replicas, rank),
|
||||
batch_size=None)
|
||||
self.iterator = None
|
||||
|
||||
def reset(self):
|
||||
if self.iterator is None:
|
||||
self.iterator = iter(self.loader)
|
||||
try:
|
||||
self.current_sample, self.step_count = next(self.iterator)
|
||||
self.current_step = 0
|
||||
return self.current_sample
|
||||
except StopIteration:
|
||||
self.iterator = None
|
||||
return None
|
||||
|
||||
def step(self, action):
|
||||
self.current_step += 1
|
||||
assert self.current_step <= self.step_count
|
||||
return 0, 1.0, self.current_step >= self.step_count, \
|
||||
{'sample': self.current_sample, 'action': action, 'metric': 2.0}
|
||||
|
||||
|
||||
class FiniteVectorEnv(BaseVectorEnv):
|
||||
def __init__(self, env_fns, **kwargs):
|
||||
super().__init__(env_fns, **kwargs)
|
||||
self._alive_env_ids = set()
|
||||
self._reset_alive_envs()
|
||||
self._default_obs = self._default_info = None
|
||||
|
||||
def _reset_alive_envs(self):
|
||||
if not self._alive_env_ids:
|
||||
# starting or running out
|
||||
self._alive_env_ids = set(range(self.env_num))
|
||||
|
||||
# to workaround with tianshou's buffer and batch
|
||||
def _set_default_obs(self, obs):
|
||||
if obs is not None and self._default_obs is None:
|
||||
self._default_obs = copy.deepcopy(obs)
|
||||
|
||||
def _set_default_info(self, info):
|
||||
if info is not None and self._default_info is None:
|
||||
self._default_info = copy.deepcopy(info)
|
||||
|
||||
def _get_default_obs(self):
|
||||
return copy.deepcopy(self._default_obs)
|
||||
|
||||
def _get_default_info(self):
|
||||
return copy.deepcopy(self._default_info)
|
||||
# END
|
||||
|
||||
def reset(self, id=None):
|
||||
id = self._wrap_id(id)
|
||||
self._reset_alive_envs()
|
||||
|
||||
# ask super to reset alive envs and remap to current index
|
||||
request_id = list(filter(lambda i: i in self._alive_env_ids, id))
|
||||
obs = [None] * len(id)
|
||||
id2idx = {i: k for k, i in enumerate(id)}
|
||||
if request_id:
|
||||
for i, o in zip(request_id, super().reset(request_id)):
|
||||
obs[id2idx[i]] = o
|
||||
for i, o in zip(id, obs):
|
||||
if o is None and i in self._alive_env_ids:
|
||||
self._alive_env_ids.remove(i)
|
||||
|
||||
# fill empty observation with default(fake) observation
|
||||
for o in obs:
|
||||
self._set_default_obs(o)
|
||||
for i in range(len(obs)):
|
||||
if obs[i] is None:
|
||||
obs[i] = self._get_default_obs()
|
||||
|
||||
if not self._alive_env_ids:
|
||||
self.reset()
|
||||
raise StopIteration
|
||||
|
||||
return np.stack(obs)
|
||||
|
||||
def step(self, action, id=None):
|
||||
id = self._wrap_id(id)
|
||||
id2idx = {i: k for k, i in enumerate(id)}
|
||||
request_id = list(filter(lambda i: i in self._alive_env_ids, id))
|
||||
result = [[None, 0., False, None] for _ in range(len(id))]
|
||||
|
||||
# ask super to step alive envs and remap to current index
|
||||
if request_id:
|
||||
valid_act = np.stack([action[id2idx[i]] for i in request_id])
|
||||
for i, r in zip(request_id, zip(*super().step(valid_act, request_id))):
|
||||
result[id2idx[i]] = r
|
||||
|
||||
# logging
|
||||
for i, r in zip(id, result):
|
||||
if i in self._alive_env_ids:
|
||||
self.tracker.log(*r)
|
||||
|
||||
# fill empty observation/info with default(fake)
|
||||
for _, __, ___, i in result:
|
||||
self._set_default_info(i)
|
||||
for i in range(len(result)):
|
||||
if result[i][0] is None:
|
||||
result[i][0] = self._get_default_obs()
|
||||
if result[i][3] is None:
|
||||
result[i][3] = self._get_default_info()
|
||||
|
||||
return list(map(np.stack, zip(*result)))
|
||||
|
||||
|
||||
class FiniteDummyVectorEnv(FiniteVectorEnv, DummyVectorEnv):
|
||||
pass
|
||||
|
||||
|
||||
class FiniteSubprocVectorEnv(FiniteVectorEnv, SubprocVectorEnv):
|
||||
pass
|
||||
|
||||
|
||||
class AnyPolicy(BasePolicy):
|
||||
def forward(self, batch, state=None):
|
||||
return Batch(act=np.stack([1] * len(batch)))
|
||||
|
||||
def learn(self, batch):
|
||||
pass
|
||||
|
||||
|
||||
def _finite_env_factory(dataset, num_replicas, rank):
|
||||
return lambda: FiniteEnv(dataset, num_replicas, rank)
|
||||
|
||||
|
||||
class MetricTracker:
|
||||
def __init__(self):
|
||||
self.counter = Counter()
|
||||
self.finished = set()
|
||||
|
||||
def log(self, obs, rew, done, info):
|
||||
assert rew == 1.
|
||||
index = info['sample']
|
||||
if done:
|
||||
assert index not in self.finished
|
||||
self.finished.add(index)
|
||||
self.counter[index] += 1
|
||||
|
||||
def validate(self):
|
||||
assert len(self.finished) == 100
|
||||
for k, v in self.counter.items():
|
||||
assert v == k * 3 % 5 + 1
|
||||
|
||||
|
||||
def test_finite_dummy_vector_env():
|
||||
dataset = DummyDataset(100)
|
||||
envs = FiniteSubprocVectorEnv([
|
||||
_finite_env_factory(dataset, 5, i) for i in range(5)])
|
||||
policy = AnyPolicy()
|
||||
test_collector = Collector(policy, envs, exploration_noise=True)
|
||||
|
||||
for _ in range(3):
|
||||
envs.tracker = MetricTracker()
|
||||
try:
|
||||
test_collector.collect(n_step=10 ** 18)
|
||||
except StopIteration:
|
||||
envs.tracker.validate()
|
||||
|
||||
|
||||
def test_finite_subproc_vector_env():
|
||||
dataset = DummyDataset(100)
|
||||
envs = FiniteSubprocVectorEnv([
|
||||
_finite_env_factory(dataset, 5, i) for i in range(5)])
|
||||
policy = AnyPolicy()
|
||||
test_collector = Collector(policy, envs, exploration_noise=True)
|
||||
|
||||
for _ in range(3):
|
||||
envs.tracker = MetricTracker()
|
||||
try:
|
||||
test_collector.collect(n_step=10 ** 18)
|
||||
except StopIteration:
|
||||
envs.tracker.validate()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_finite_dummy_vector_env()
|
||||
test_finite_subproc_vector_env()
|
Loading…
x
Reference in New Issue
Block a user