diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index fc9af2f..8e67963 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -30,7 +30,7 @@ jobs: - name: Lint with flake8 run: | pip install flake8 - ./flake_check.sh + ./flake8.sh - name: Test with pytest run: | pip install pytest diff --git a/flake_check.sh b/flake8.sh similarity index 100% rename from flake_check.sh rename to flake8.sh diff --git a/test/test_env.py b/test/test_env.py new file mode 100644 index 0000000..e56afff --- /dev/null +++ b/test/test_env.py @@ -0,0 +1,57 @@ +import gym +import time +import numpy as np +from tianshou.env import FrameStack, VectorEnv, SubprocVectorEnv, RayVectorEnv + + +class MyTestEnv(gym.Env): + def __init__(self, size, sleep=0): + self.size = size + self.sleep = sleep + self.index = 0 + + def reset(self): + self.index = 0 + return self.index + + def step(self, action): + if self.sleep > 0: + time.sleep(self.sleep) + if self.index == self.size: + return self.index, 0, True, {} + if action == 0: + self.index = max(self.index - 1, 0) + return self.index, 0, False, {} + elif action == 1: + self.index += 1 + finished = self.index == self.size + return self.index, int(finished), finished, {} + +def test_framestack(): + k = 4 + size = 10 + env = MyTestEnv(size=size) + fsenv = FrameStack(env, k) + fsenv.seed() + obs = fsenv.reset() + assert abs(obs - np.array([0, 0, 0, 0])).sum() == 0 + for i in range(5): + obs, rew, done, info = fsenv.step(1) + assert abs(obs - np.array([2, 3, 4, 5])).sum() == 0 + for i in range(10): + obs, rew, done, info = fsenv.step(0) + assert abs(obs - np.array([0, 0, 0, 0])).sum() == 0 + for i in range(9): + obs, rew, done, info = fsenv.step(1) + assert abs(obs - np.array([6, 7, 8, 9])).sum() == 0 + assert (rew, done) == (0, False) + obs, rew, done, info = fsenv.step(1) + assert abs(obs - np.array([7, 8, 9, 10])).sum() == 0 + assert (rew, done) == (1, True) + obs, rew, done, info = fsenv.step(0) + assert abs(obs - np.array([8, 9, 10, 10])).sum() == 0 + assert (rew, done) == (0, True) + fsenv.close() + +if __name__ == '__main__': + test_framestack() \ No newline at end of file diff --git a/tianshou/env/wrapper.py b/tianshou/env/wrapper.py index c1b0c10..ca81f97 100644 --- a/tianshou/env/wrapper.py +++ b/tianshou/env/wrapper.py @@ -50,7 +50,10 @@ class FrameStack(EnvWrapper): return self._get_obs() def _get_obs(self): - return np.concatenate(self._frames, axis=-1) + try: + return np.concatenate(self._frames, axis=-1) + except ValueError: + return np.stack(self._frames, axis=-1) class VectorEnv(object): @@ -177,11 +180,10 @@ class RayVectorEnv(object): self.env_num = len(env_fns) self._reset_after_done = kwargs.get('reset_after_done', False) try: - import ray - except ImportError: + if not ray.is_initialized(): + ray.init() + except NameError: raise ImportError('Please install ray to support VectorEnv: pip3 install ray -U') - if not ray.is_initialized(): - ray.init() self.envs = [ray.remote(EnvWrapper).options(num_cpus=0).remote(e()) for e in env_fns] def __len__(self): diff --git a/tianshou/utils/__init__.py b/tianshou/utils/__init__.py index 128ba80..266c2b2 100644 --- a/tianshou/utils/__init__.py +++ b/tianshou/utils/__init__.py @@ -1,3 +1,4 @@ from tianshou.utils.cloudpicklewrapper import CloudpickleWrapper +from tianshou.utils.config import tqdm_config -__all__ = ['CloudpickleWrapper'] +__all__ = ['CloudpickleWrapper', 'tqdm_config'] diff --git a/tianshou/utils/config.py b/tianshou/utils/config.py new file mode 100644 index 0000000..4cf8503 --- /dev/null +++ b/tianshou/utils/config.py @@ -0,0 +1,4 @@ +tqdm_config = { + 'dynamic_ncols': True, + 'ascii': True, +}