Tianshou/test/throughput/test_buffer_profile.py
Markus Krimmel 6c6c872523
Gymnasium Integration (#789)
Changes:
- Disclaimer in README
- Replaced all occurrences of Gym with Gymnasium
- Removed code that is now dead since we no longer need to support the
old step API
- Updated type hints to only allow new step API
- Increased required version of envpool to support Gymnasium
- Increased required version of PettingZoo to support Gymnasium
- Updated `PettingZooEnv` to only use the new step API, removed hack to
also support old API
- I had to add some `# type: ignore` comments because of the new type
hints in Gymnasium. I'm not that familiar with type hinting, but I believe
the issue is on the Gymnasium side, and we are looking into it.
- Had to update `MyTestEnv` to support `options` kwarg
- Skip NNI tests because they still use OpenAI Gym
- Also allow `PettingZooEnv` in vector environment
- Updated doc page about ReplayBuffer to also talk about terminated and
truncated flags.
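
For reference, the new step API that the changes above target can be sketched roughly as follows. This is only an illustration, not code from this change: `CountdownEnv` and `rollout` are hypothetical stand-ins that mimic the Gymnasium `reset`/`step` signatures, so nothing here depends on Gymnasium being installed.

```python
from typing import Any, Dict, List, Optional, Tuple


class CountdownEnv:
    """Hypothetical stand-in that mimics the Gymnasium reset/step signatures."""

    def __init__(self, horizon: int = 3, time_limit: int = 5) -> None:
        self.horizon = horizon        # steps until the task itself ends
        self.time_limit = time_limit  # steps until an external cutoff
        self.t = 0

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> Tuple[int, Dict[str, Any]]:
        # New-style reset returns (obs, info) and accepts an `options` kwarg.
        if options and "horizon" in options:
            self.horizon = options["horizon"]
        self.t = 0
        return self.t, {}

    def step(self, action: int) -> Tuple[int, float, bool, bool, Dict[str, Any]]:
        # New-style step returns a five-tuple instead of a single `done` flag.
        self.t += 1
        terminated = self.t >= self.horizon    # natural end of the episode
        truncated = self.t >= self.time_limit  # cut off by a time limit
        return self.t, 1.0, terminated, truncated, {}


def rollout(
    env: CountdownEnv, options: Optional[Dict[str, Any]] = None
) -> List[Tuple[bool, bool, bool]]:
    """Run one episode and record (terminated, truncated, done) per step."""
    obs, info = env.reset(options=options)
    flags = []
    done = False
    while not done:
        obs, rew, terminated, truncated, info = env.step(0)
        # A caller that still wants one flag can combine the two.
        done = terminated or truncated
        flags.append((terminated, truncated, done))
    return flags
```

With `CountdownEnv(horizon=10, time_limit=2)` the episode ends through `truncated` while `terminated` stays false, which is the distinction the old single `done` flag could not express.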

Still need to do: 
- Update the Jupyter notebooks in docs
- Check the entire code base for more dead code (from compatibility
stuff)
- Check the reset functions of all environments/wrappers in code base to
make sure they use the `options` kwarg
- Someone might want to check test_env_finite.py
- Is it okay to allow `PettingZooEnv` in vector environments? Might need
to update docs?
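
The `options` item in the list above mostly amounts to checking that each wrapper's `reset` passes the keyword through to the environment it wraps. A rough sketch of that pattern, using hypothetical `RecordingEnv` and `PassthroughWrapper` classes that are not part of the code base:

```python
from typing import Any, Dict, Optional, Tuple


class RecordingEnv:
    """Hypothetical inner environment that records the options it receives."""

    def __init__(self) -> None:
        self.last_options: Optional[Dict[str, Any]] = None

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> Tuple[int, Dict[str, Any]]:
        self.last_options = options
        return 0, {"seed": seed}


class PassthroughWrapper:
    """Hypothetical wrapper whose reset forwards `seed` and `options` unchanged."""

    def __init__(self, env: RecordingEnv) -> None:
        self.env = env

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        options: Optional[Dict[str, Any]] = None,
    ) -> Tuple[int, Dict[str, Any]]:
        # Forward both keywords explicitly so the inner env sees them.
        return self.env.reset(seed=seed, options=options)


inner = RecordingEnv()
wrapped = PassthroughWrapper(inner)
obs, info = wrapped.reset(seed=7, options={"start": 1})
```

A wrapper that defines `reset(self)` without these keywords would drop `options` before it reaches the inner environment, which is the case worth looking for.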
2023-02-03 11:57:27 -08:00

69 lines
2.2 KiB
Python

import sys
import time

import gymnasium as gym
import numpy as np
import tqdm

from tianshou.data import Batch, ReplayBuffer, VectorReplayBuffer


def test_replaybuffer(task="Pendulum-v1"):
    """Profile `ReplayBuffer.add` with random-action transitions from `task`."""
    total_count = 5
    for _ in tqdm.trange(total_count, desc="ReplayBuffer"):
        env = gym.make(task)
        buf = ReplayBuffer(10000)
        obs, info = env.reset()
        for _ in range(100000):
            act = env.action_space.sample()
            obs_next, rew, terminated, truncated, info = env.step(act)
            # The episode is over once it has terminated or been truncated.
            done = terminated or truncated
            batch = Batch(
                obs=np.array([obs]),
                act=np.array([act]),
                rew=np.array([rew]),
                terminated=np.array([terminated]),
                truncated=np.array([truncated]),
                done=np.array([done]),
                obs_next=np.array([obs_next]),
                info=np.array([info]),
            )
            buf.add(batch, buffer_ids=[0])
            obs = obs_next
            if done:
                obs, info = env.reset()


def test_vectorbuffer(task="Pendulum-v1"):
    """Profile `VectorReplayBuffer.add` with random-action transitions from `task`."""
    total_count = 5
    for _ in tqdm.trange(total_count, desc="VectorReplayBuffer"):
        env = gym.make(task)
        buf = VectorReplayBuffer(total_size=10000, buffer_num=1)
        obs, info = env.reset()
        for _ in range(100000):
            act = env.action_space.sample()
            obs_next, rew, terminated, truncated, info = env.step(act)
            # `done` is not stored in the batch here; it only decides when to reset.
            done = terminated or truncated
            batch = Batch(
                obs=np.array([obs]),
                act=np.array([act]),
                rew=np.array([rew]),
                terminated=np.array([terminated]),
                truncated=np.array([truncated]),
                obs_next=np.array([obs_next]),
                info=np.array([info]),
            )
            buf.add(batch)
            obs = obs_next
            if done:
                obs, info = env.reset()
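

if __name__ == '__main__':
    # Use the task given on the command line; without an argument,
    # sys.argv[-1] would be the script path, so fall back to the default.
    task = sys.argv[1] if len(sys.argv) > 1 else "Pendulum-v1"
    t0 = time.time()
    test_replaybuffer(task)
    print("test replaybuffer: ", time.time() - t0)
    t0 = time.time()
    test_vectorbuffer(task)
    print("test vectorbuffer: ", time.time() - t0)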