1. Add `policy.eval()` to the "watch performance" step of every test script.
2. Remove dict return support from the collector's `preprocess_fn` (see the sketch below the list).
3. Add `__contains__` and `pop` to `Batch`: `key in batch`, `batch.pop(key, default)` (see the sketch below the list).
4. Collect exactly `n_episode` episodes when a list of per-env episode counts is given, and save fake data into `cache_buffer` when `self.buffer` is None (#184).
5. Fix TensorBoard logging: the horizontal axis now stands for env steps instead of gradient steps; test results are also logged to TensorBoard.
6. Add `test_returns` (covering both GAE and n-step returns).
7. Reorder the type checks in `batch.py` and `converter.py` so the most frequent case is handled first.
8. Fix a shape inconsistency for `torch.Tensor` in the replay buffer.
9. Remove `**kwargs` from `ReplayBuffer`.
10. Remove the default value of `batch.split()` and add a `merge_last` argument (#185).
11. Improve n-step efficiency.
12. Add `max_batchsize` to on-policy algorithms.
13. Potential bugfix for `subproc.wait`.
14. Fix `RecurrentActorProb`.
15. Improve code coverage (from 90% to 95%) and remove dead code.
16. Fix some incorrect type annotations.

These improvements also increase the training FPS: on my machine the previous version reaches only ~1800 FPS, while this one reaches ~2050 (faster than v0.2.4.post1).
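Items 3 and 10 change the `Batch` API. Below is a minimal sketch of how the new helpers are meant to be used; it assumes `Batch.pop` mirrors `dict.pop` and that `split` now requires an explicit size, as the entries above describe.

```python
import numpy as np
from tianshou.data import Batch

batch = Batch(obs=np.zeros((8, 4)), act=np.arange(8))

# __contains__ tests membership of top-level keys
assert "obs" in batch and "rew" not in batch

# pop mirrors dict.pop, with an optional default value
act = batch.pop("act", None)
assert "act" not in batch

# split no longer has a default size; with merge_last=True the short
# trailing chunk is merged into the previous minibatch (3 + 5, not 3 + 3 + 2)
for minibatch in batch.split(3, shuffle=False, merge_last=True):
    print(len(minibatch))
```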
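Item 2 means a `preprocess_fn` passed to `Collector` must now return a `Batch` rather than a dict. A hedged sketch, assuming the collector invokes it with keyword arguments such as `rew` (the exact keys follow the collector's internals and may differ):

```python
import numpy as np
from tianshou.data import Batch

def preprocess_fn(**kwargs):
    # Return a Batch containing only the keys you modified;
    # an empty Batch leaves the collected data untouched.
    if "rew" in kwargs:
        return Batch(rew=np.sign(kwargs["rew"]))  # e.g. reward clipping
    return Batch()
```

The throughput-profiling script below exercises the collector paths touched by these changes against a trivial environment and policy.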
import gym
import numpy as np
import pytest
from gym.spaces.discrete import Discrete
from gym.utils import seeding

from tianshou.data import Batch, Collector, ReplayBuffer
from tianshou.env import DummyVectorEnv, SubprocVectorEnv
from tianshou.policy import BasePolicy


class SimpleEnv(gym.Env):
    """A minimal self-defined env, used to minimize data collection
    time while profiling the collector."""

    def __init__(self):
        self.action_space = Discrete(200)
        self._fake_data = np.ones((10, 10, 1))
        self.seed(0)
        self.reset()

    def reset(self):
        self._index = 0
        # episode length is drawn uniformly from [3, 200)
        self.done = np.random.randint(3, high=200)
        return {'observable': np.zeros((10, 10, 1)), 'hidden': self._index}

    def step(self, action):
        if self._index == self.done:
            raise ValueError('step after done!')
        self._index += 1
        return {'observable': self._fake_data, 'hidden': self._index}, -1, \
            self._index == self.done, {}

    def seed(self, seed=None):
        self.np_random, seed = seeding.np_random(seed)
        return [seed]


class SimplePolicy(BasePolicy):
    """A minimal self-defined policy, used to minimize data collection
    time: it always picks action 30."""

    def learn(self, batch, **kwargs):
        return super().learn(batch, **kwargs)

    def forward(self, batch, state=None, **kwargs):
        return Batch(act=np.array([30] * len(batch)), state=None, logits=None)


@pytest.fixture(scope="module")
|
|
def data():
|
|
np.random.seed(0)
|
|
env = SimpleEnv()
|
|
env.seed(0)
|
|
env_vec = DummyVectorEnv([lambda: SimpleEnv() for _ in range(100)])
|
|
env_vec.seed(np.random.randint(1000, size=100).tolist())
|
|
env_subproc = SubprocVectorEnv([lambda: SimpleEnv() for _ in range(8)])
|
|
env_subproc.seed(np.random.randint(1000, size=100).tolist())
|
|
env_subproc_init = SubprocVectorEnv(
|
|
[lambda: SimpleEnv() for _ in range(8)])
|
|
env_subproc_init.seed(np.random.randint(1000, size=100).tolist())
|
|
buffer = ReplayBuffer(50000)
|
|
policy = SimplePolicy()
|
|
collector = Collector(policy, env, ReplayBuffer(50000))
|
|
collector_vec = Collector(policy, env_vec, ReplayBuffer(50000))
|
|
collector_subproc = Collector(policy, env_subproc, ReplayBuffer(50000))
|
|
return {
|
|
"env": env,
|
|
"env_vec": env_vec,
|
|
"env_subproc": env_subproc,
|
|
"env_subproc_init": env_subproc_init,
|
|
"policy": policy,
|
|
"buffer": buffer,
|
|
"collector": collector,
|
|
"collector_vec": collector_vec,
|
|
"collector_subproc": collector_subproc,
|
|
}
|
|
|
|
|
|
def test_init(data):
    for _ in range(5000):
        Collector(data["policy"], data["env"], data["buffer"])


def test_reset(data):
    for _ in range(5000):
        data["collector"].reset()


def test_collect_st(data):
    for _ in range(50):
        data["collector"].collect(n_step=1000)


def test_collect_ep(data):
    for _ in range(50):
        data["collector"].collect(n_episode=10)


def test_sample(data):
    for _ in range(5000):
        data["collector"].sample(256)


def test_init_vec_env(data):
    for _ in range(5000):
        Collector(data["policy"], data["env_vec"], data["buffer"])


def test_reset_vec_env(data):
    for _ in range(5000):
        data["collector_vec"].reset()


def test_collect_vec_env_st(data):
    for _ in range(50):
        data["collector_vec"].collect(n_step=1000)


def test_collect_vec_env_ep(data):
    for _ in range(50):
        data["collector_vec"].collect(n_episode=10)


def test_sample_vec_env(data):
    for _ in range(5000):
        data["collector_vec"].sample(256)


def test_init_subproc_env(data):
    for _ in range(5000):
        Collector(data["policy"], data["env_subproc_init"], data["buffer"])


def test_reset_subproc_env(data):
    for _ in range(5000):
        data["collector_subproc"].reset()


def test_collect_subproc_env_st(data):
    for _ in range(50):
        data["collector_subproc"].collect(n_step=1000)


def test_collect_subproc_env_ep(data):
    for _ in range(50):
        data["collector_subproc"].collect(n_episode=10)


def test_sample_subproc_env(data):
    for _ in range(5000):
        data["collector_subproc"].sample(256)


if __name__ == '__main__':
    pytest.main(["-s", "-k collector_profile", "--durations=0", "-v"])