Tianshou/test/throughput/test_collector_profile.py
youkaichao a9f9940d17
code refactor for venv (#179)
- Refactor code to remove duplicate code

- Enable async simulation for all vector envs

- Remove `collector.close` and rename `VectorEnv` to `DummyVectorEnv`

The abstraction of vector env changed.

Prior to this PR, each vector env was implemented almost independently.

After this PR, each env is wrapped into a worker, and vector envs differ only in their worker type. In fact, users can simply use `BaseVectorEnv` with different workers; I keep `SubprocVectorEnv` and `ShmemVectorEnv` for backward compatibility.

Co-authored-by: n+e <463003665@qq.com>
Co-authored-by: magicly <magicly007@gmail.com>
2020-08-19 15:00:24 +08:00

163 lines
4.3 KiB
Python

import gym
import numpy as np
import pytest
from gym.spaces.discrete import Discrete
from gym.utils import seeding
from tianshou.data import Batch, Collector, ReplayBuffer
from tianshou.env import DummyVectorEnv, SubprocVectorEnv
from tianshou.policy import BasePolicy
class SimpleEnv(gym.Env):
    """A simplest example of self-defined env, used to minimize
    data collect time and profile collector.

    Every episode lasts a random number of steps in ``[3, 200)``. That
    length is drawn from the env's own ``self.np_random`` generator, so
    :meth:`seed` really controls episode lengths. Observations are dicts
    with an ``observable`` array of shape ``(10, 10, 1)`` and a ``hidden``
    step counter. Every step yields reward ``-1``.
    """

    def __init__(self):
        self.action_space = Discrete(200)
        # A single preallocated array is returned by every step, which
        # keeps per-step cost negligible.
        self._fake_data = np.ones((10, 10, 1))
        self.seed(0)
        self.reset()

    def reset(self):
        """Start a new episode and return the initial observation."""
        self._index = 0
        # Draw from the per-env RNG created in ``seed`` rather than the
        # global ``np.random``; otherwise seeding has no effect.
        rng = self.np_random
        # Depending on the gym version, ``np_random`` may be a RandomState
        # (``randint``) or a Generator (``integers``). Both exclude ``high``.
        draw = rng.integers if hasattr(rng, 'integers') else rng.randint
        # ``done`` is the step index at which the episode terminates.
        self.done = int(draw(3, 200))
        return {'observable': np.zeros((10, 10, 1)),
                'hidden': self._index}

    def step(self, action):
        """Advance one step; ``action`` is ignored.

        Raises:
            ValueError: if called after the episode has already ended.
        """
        if self._index == self.done:
            raise ValueError('step after done !!!')
        self._index += 1
        obs = {'observable': self._fake_data, 'hidden': self._index}
        return obs, -1, self._index == self.done, {}

    def seed(self, seed=None):
        """Create this env's RNG from ``seed`` and return ``[seed]``."""
        self.np_random, seed = seeding.np_random(seed)
        return [seed]
class SimplePolicy(BasePolicy):
    """Trivial policy that keeps data-collection time as small as possible.

    It has no learning logic of its own. For every entry in the incoming
    batch it emits the constant action ``30``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def learn(self, batch, **kwargs):
        # Nothing to train here; defer to the base implementation.
        return super().learn(batch, **kwargs)

    def forward(self, batch, state=None, **kwargs):
        constant_action = 30
        actions = np.array([constant_action for _ in range(len(batch))])
        return Batch(act=actions, state=None, logits=None)
@pytest.fixture(scope="module")
def data():
    """Build the envs, policy, buffer and collectors shared by this module.

    The fixture yields a dict of named objects. After the module's tests
    finish, the vectorized envs are closed. ``SubprocVectorEnv`` runs each
    env in a worker subprocess, so closing releases those workers instead
    of leaking them.
    """
    np.random.seed(0)
    num_dummy_envs = 100
    num_subproc_envs = 8

    env = SimpleEnv()
    env.seed(0)

    # The env class itself serves as the env factory. Each vector env gets
    # exactly one seed per worker.
    env_vec = DummyVectorEnv([SimpleEnv for _ in range(num_dummy_envs)])
    env_vec.seed(np.random.randint(1000, size=num_dummy_envs).tolist())
    env_subproc = SubprocVectorEnv(
        [SimpleEnv for _ in range(num_subproc_envs)])
    env_subproc.seed(
        np.random.randint(1000, size=num_subproc_envs).tolist())
    env_subproc_init = SubprocVectorEnv(
        [SimpleEnv for _ in range(num_subproc_envs)])
    env_subproc_init.seed(
        np.random.randint(1000, size=num_subproc_envs).tolist())

    buffer = ReplayBuffer(50000)
    policy = SimplePolicy()
    collector = Collector(policy, env, ReplayBuffer(50000))
    collector_vec = Collector(policy, env_vec, ReplayBuffer(50000))
    collector_subproc = Collector(policy, env_subproc, ReplayBuffer(50000))
    yield {
        "env": env,
        "env_vec": env_vec,
        "env_subproc": env_subproc,
        "env_subproc_init": env_subproc_init,
        "policy": policy,
        "buffer": buffer,
        "collector": collector,
        "collector_vec": collector_vec,
        "collector_subproc": collector_subproc,
    }
    # Teardown: runs once after the last test in this module.
    for venv in (env_vec, env_subproc, env_subproc_init):
        venv.close()
def test_init(data):
    """Profile 5000 Collector constructions over the single env."""
    policy, env, buffer = data["policy"], data["env"], data["buffer"]
    for _repeat in range(5000):
        Collector(policy, env, buffer)
def test_reset(data):
    """Profile 5000 resets of the single-env collector."""
    collector = data["collector"]
    for _repeat in range(5000):
        collector.reset()
def test_collect_st(data):
    """Profile step-based collection on the single env (50 x 1000 steps)."""
    collector = data["collector"]
    for _round in range(50):
        collector.collect(n_step=1000)
def test_collect_ep(data):
    """Profile episode-based collection on the single env (50 x 10 eps)."""
    collector = data["collector"]
    for _round in range(50):
        collector.collect(n_episode=10)
def test_sample(data):
    """Profile 5000 draws of 256 samples from the single-env collector."""
    collector = data["collector"]
    for _repeat in range(5000):
        collector.sample(256)
def test_init_vec_env(data):
    """Profile 5000 Collector constructions over the DummyVectorEnv."""
    policy, env, buffer = data["policy"], data["env_vec"], data["buffer"]
    for _repeat in range(5000):
        Collector(policy, env, buffer)
def test_reset_vec_env(data):
    """Profile 5000 resets of the DummyVectorEnv collector."""
    collector = data["collector_vec"]
    for _repeat in range(5000):
        collector.reset()
def test_collect_vec_env_st(data):
    """Profile step-based collection on the DummyVectorEnv (50 x 1000)."""
    collector = data["collector_vec"]
    for _round in range(50):
        collector.collect(n_step=1000)
def test_collect_vec_env_ep(data):
    """Profile episode-based collection on the DummyVectorEnv (50 x 10)."""
    collector = data["collector_vec"]
    for _round in range(50):
        collector.collect(n_episode=10)
def test_sample_vec_env(data):
    """Profile 5000 draws of 256 samples from the DummyVectorEnv collector."""
    collector = data["collector_vec"]
    for _repeat in range(5000):
        collector.sample(256)
def test_init_subproc_env(data):
    """Profile 5000 Collector constructions over a SubprocVectorEnv.

    This uses the separate ``env_subproc_init`` instance, presumably so
    these repeated constructions leave ``collector_subproc``'s env alone.
    """
    policy = data["policy"]
    env = data["env_subproc_init"]
    buffer = data["buffer"]
    for _repeat in range(5000):
        Collector(policy, env, buffer)
def test_reset_subproc_env(data):
    """Profile 5000 resets of the SubprocVectorEnv collector."""
    collector = data["collector_subproc"]
    for _repeat in range(5000):
        collector.reset()
def test_collect_subproc_env_st(data):
    """Profile step-based collection on the SubprocVectorEnv (50 x 1000)."""
    collector = data["collector_subproc"]
    for _round in range(50):
        collector.collect(n_step=1000)
def test_collect_subproc_env_ep(data):
    """Profile episode-based collection on the SubprocVectorEnv (50 x 10)."""
    collector = data["collector_subproc"]
    for _round in range(50):
        collector.collect(n_episode=10)
def test_sample_subproc_env(data):
    """Profile 5000 draws of 256 samples from the SubprocVectorEnv collector."""
    collector = data["collector_subproc"]
    for _repeat in range(5000):
        collector.sample(256)
if __name__ == '__main__':
    # Run exactly this file's profiling tests, rather than collecting every
    # test under the cwd and filtering by keyword. Propagate pytest's exit
    # code so failures are visible to the calling shell.
    raise SystemExit(pytest.main([__file__, "-s", "--durations=0", "-v"]))