Tianshou/test/throughput/test_collector_profile.py
n+e b284ace102
type check in unit test (#200)
Fix #195: Add mypy test in .github/workflows/docs_and_lint.yml.

Also remove the out-of-date API.
2020-09-13 19:31:50 +08:00

145 lines
4.0 KiB
Python

import gym
import numpy as np
import pytest
from gym.spaces.discrete import Discrete
from gym.utils import seeding
from tianshou.data import Batch, Collector, ReplayBuffer
from tianshou.env import DummyVectorEnv, SubprocVectorEnv
from tianshou.policy import BasePolicy
class SimpleEnv(gym.Env):
    """A minimal self-defined env, used to minimize data-collection time
    and profile the collector.

    Observations are dicts holding an ``'observable'`` array of shape
    (10, 10, 1) and a ``'hidden'`` step counter. Every step gives reward
    -1. An episode lasts a random number of steps in [3, 200), drawn from
    the env's own RNG, so :meth:`seed` controls episode lengths.
    """

    def __init__(self):
        self.action_space = Discrete(200)
        # Returned by every ``step`` call so stepping allocates nothing new.
        self._fake_data = np.ones((10, 10, 1))
        self.seed(0)
        self.reset()

    def reset(self):
        """Start a new episode and return its initial observation."""
        self._index = 0
        # Draw from the per-instance RNG created by ``seed`` instead of the
        # global ``np.random`` state. Otherwise ``seed`` has no effect, and
        # fork-started subprocess workers that inherit the same global
        # state would all get identical episode lengths.
        # NOTE(review): assumes ``seeding.np_random`` returns an object
        # exposing ``randint`` (true for the gym versions with Env.seed).
        self.done = self.np_random.randint(3, high=200)
        return {'observable': np.zeros((10, 10, 1)), 'hidden': self._index}

    def step(self, action):
        """Advance one step. ``action`` is ignored.

        Returns:
            ``(obs, reward, done, info)`` with reward -1 and empty info.

        Raises:
            ValueError: if called after the episode has ended.
        """
        if self._index == self.done:
            raise ValueError('step after done !!!')
        self._index += 1
        obs = {'observable': self._fake_data, 'hidden': self._index}
        return obs, -1, self._index == self.done, {}

    def seed(self, seed=None):
        """Reseed this env's RNG and return ``[seed_used]``."""
        self.np_random, seed = seeding.np_random(seed)
        return [seed]
class SimplePolicy(BasePolicy):
    """Minimal self-defined policy that always chooses action 30.

    It keeps policy overhead negligible while data collection is profiled.
    """

    def __init__(self, **kwargs):
        # Pure pass-through to the base constructor (keyword args only).
        super().__init__(**kwargs)

    def learn(self, batch, **kwargs):
        # Nothing to train here: defer to the base implementation.
        return super().learn(batch, **kwargs)

    def forward(self, batch, state=None, **kwargs):
        """Return a Batch with ``act`` set to 30 for every entry of ``batch``.

        ``state`` and ``logits`` in the result are both None.
        """
        count = len(batch)
        actions = np.array(count * [30])
        return Batch(act=actions, state=None, logits=None)
@pytest.fixture(scope="module")
def data():
    """Build the envs, policy, buffer and collectors shared by this module.

    The vectorized envs are closed once all tests in the module have run.

    Yields:
        dict with keys "env", "env_vec", "env_subproc", "env_subproc_init",
        "policy", "buffer", "collector", "collector_vec" and
        "collector_subproc".
    """
    num_vec_envs = 100
    num_subproc_envs = 8
    np.random.seed(0)
    env = SimpleEnv()
    env.seed(0)
    # The class itself serves as the env factory; no lambda wrapper needed.
    env_vec = DummyVectorEnv([SimpleEnv] * num_vec_envs)
    env_vec.seed(np.random.randint(1000, size=num_vec_envs).tolist())
    env_subproc = SubprocVectorEnv([SimpleEnv] * num_subproc_envs)
    # Draw exactly one seed per worker (previously 100 were drawn for 8).
    env_subproc.seed(np.random.randint(1000, size=num_subproc_envs).tolist())
    # Separate subprocess env, used by test_init_subproc_env, so repeated
    # Collector construction does not touch env_subproc.
    env_subproc_init = SubprocVectorEnv([SimpleEnv] * num_subproc_envs)
    env_subproc_init.seed(
        np.random.randint(1000, size=num_subproc_envs).tolist())
    buffer = ReplayBuffer(50000)
    policy = SimplePolicy()
    collector = Collector(policy, env, ReplayBuffer(50000))
    collector_vec = Collector(policy, env_vec, ReplayBuffer(50000))
    collector_subproc = Collector(policy, env_subproc, ReplayBuffer(50000))
    yield {
        "env": env,
        "env_vec": env_vec,
        "env_subproc": env_subproc,
        "env_subproc_init": env_subproc_init,
        "policy": policy,
        "buffer": buffer,
        "collector": collector,
        "collector_vec": collector_vec,
        "collector_subproc": collector_subproc,
    }
    # Teardown: release the vectorized envs. The subprocess workers in
    # particular are not closed anywhere else in this module.
    for venv in (env_vec, env_subproc, env_subproc_init):
        venv.close()
def test_init(data):
    """Profile constructing a Collector over the single SimpleEnv."""
    policy, env, buf = data["policy"], data["env"], data["buffer"]
    for _ in range(5000):
        Collector(policy, env, buf)
def test_reset(data):
    """Profile repeated resets of the single-env collector."""
    reset = data["collector"].reset
    for _ in range(5000):
        reset()
def test_collect_st(data):
    """Profile step-based collection (1000 steps per call) on one env."""
    single = data["collector"]
    for _ in range(50):
        single.collect(n_step=1000)
def test_collect_ep(data):
    """Profile episode-based collection (10 episodes per call) on one env."""
    single = data["collector"]
    for _ in range(50):
        single.collect(n_episode=10)
def test_init_vec_env(data):
    """Profile constructing a Collector over the DummyVectorEnv."""
    policy, venv, buf = data["policy"], data["env_vec"], data["buffer"]
    for _ in range(5000):
        Collector(policy, venv, buf)
def test_reset_vec_env(data):
    """Profile repeated resets of the DummyVectorEnv collector."""
    reset = data["collector_vec"].reset
    for _ in range(5000):
        reset()
def test_collect_vec_env_st(data):
    """Profile step-based collection (1000 steps per call) on DummyVectorEnv."""
    vec = data["collector_vec"]
    for _ in range(50):
        vec.collect(n_step=1000)
def test_collect_vec_env_ep(data):
    """Profile episode-based collection (10 episodes per call) on
    DummyVectorEnv."""
    vec = data["collector_vec"]
    for _ in range(50):
        vec.collect(n_episode=10)
def test_init_subproc_env(data):
    """Profile constructing a Collector over a dedicated SubprocVectorEnv."""
    policy = data["policy"]
    venv = data["env_subproc_init"]
    buf = data["buffer"]
    for _ in range(5000):
        Collector(policy, venv, buf)
def test_reset_subproc_env(data):
    """Profile repeated resets of the SubprocVectorEnv collector."""
    reset = data["collector_subproc"].reset
    for _ in range(5000):
        reset()
def test_collect_subproc_env_st(data):
    """Profile step-based collection (1000 steps per call) on
    SubprocVectorEnv."""
    sub = data["collector_subproc"]
    for _ in range(50):
        sub.collect(n_step=1000)
def test_collect_subproc_env_ep(data):
    """Profile episode-based collection (10 episodes per call) on
    SubprocVectorEnv."""
    sub = data["collector_subproc"]
    for _ in range(50):
        sub.collect(n_episode=10)
if __name__ == '__main__':
    # Run this file's tests regardless of the working directory, printing
    # per-test durations, and propagate pytest's exit status to the shell.
    # "-k" and its expression are separate argv entries.
    raise SystemExit(pytest.main(
        [__file__, "-s", "-k", "collector_profile", "--durations=0", "-v"]))