Changes: - Disclaimer in README - Replaced all occurrences of Gym with Gymnasium - Removed code that is now dead since we no longer need to support the old step API - Updated type hints to only allow new step API - Increased required version of envpool to support Gymnasium - Increased required version of PettingZoo to support Gymnasium - Updated `PettingZooEnv` to only use the new step API, removed hack to also support old API - I had to add some `# type: ignore` comments, due to new type hinting in Gymnasium. I'm not that familiar with type hinting but I believe that the issue is on the Gymnasium side and we are looking into it. - Had to update `MyTestEnv` to support `options` kwarg - Skip NNI tests because they still use OpenAI Gym - Also allow `PettingZooEnv` in vector environment - Updated doc page about ReplayBuffer to also talk about terminated and truncated flags. Still need to do: - Update the Jupyter notebooks in docs - Check the entire code base for more dead code (from compatibility stuff) - Check the reset functions of all environments/wrappers in code base to make sure they use the `options` kwarg - Someone might want to check test_env_finite.py - Is it okay to allow `PettingZooEnv` in vector environments? Might need to update docs?
69 lines
2.2 KiB
Python
69 lines
2.2 KiB
Python
import sys
|
|
import time
|
|
|
|
import gymnasium as gym
|
|
import numpy as np
|
|
import tqdm
|
|
|
|
from tianshou.data import Batch, ReplayBuffer, VectorReplayBuffer
|
|
|
|
|
|
def test_replaybuffer(task="Pendulum-v1"):
    """Repeatedly add random-action transitions to a ``ReplayBuffer``.

    Runs ``total_count`` rounds. Each round creates a fresh environment via
    ``gym.make(task)`` and a fresh ``ReplayBuffer(10000)``, then performs
    100000 environment steps with actions sampled from the action space,
    adding every transition to the buffer as a one-element ``Batch``. The
    environment is reset whenever an episode is terminated or truncated.

    :param task: environment id forwarded to ``gym.make``.
    """
    total_count = 5
    for _ in tqdm.trange(total_count, desc="ReplayBuffer"):
        env = gym.make(task)
        try:
            buf = ReplayBuffer(10000)
            obs, info = env.reset()
            for _ in range(100000):
                act = env.action_space.sample()
                obs_next, rew, terminated, truncated, info = env.step(act)
                done = terminated or truncated
                # Every field is wrapped in a length-1 array so the batch
                # carries a leading dimension of one transition.
                batch = Batch(
                    obs=np.array([obs]),
                    act=np.array([act]),
                    rew=np.array([rew]),
                    terminated=np.array([terminated]),
                    truncated=np.array([truncated]),
                    done=np.array([done]),
                    obs_next=np.array([obs_next]),
                    info=np.array([info]),
                )
                buf.add(batch, buffer_ids=[0])
                obs = obs_next
                if done:
                    obs, info = env.reset()
        finally:
            # A new environment is created every round; close it so its
            # resources are released even if a step or add call raises.
            env.close()
|
|
|
|
|
|
def test_vectorbuffer(task="Pendulum-v1"):
    """Repeatedly add random-action transitions to a ``VectorReplayBuffer``.

    Runs ``total_count`` rounds. Each round creates a fresh environment via
    ``gym.make(task)`` and a fresh
    ``VectorReplayBuffer(total_size=10000, buffer_num=1)``, then performs
    100000 environment steps with actions sampled from the action space,
    adding every transition to the buffer as a one-element ``Batch``. The
    environment is reset whenever an episode is terminated or truncated.

    :param task: environment id forwarded to ``gym.make``.
    """
    total_count = 5
    for _ in tqdm.trange(total_count, desc="VectorReplayBuffer"):
        env = gym.make(task)
        try:
            buf = VectorReplayBuffer(total_size=10000, buffer_num=1)
            obs, info = env.reset()
            for _ in range(100000):
                act = env.action_space.sample()
                obs_next, rew, terminated, truncated, info = env.step(act)
                done = terminated or truncated
                # Every field is wrapped in a length-1 array so the batch
                # carries a leading dimension of one transition. ``done`` is
                # only used locally to decide when to reset; it is not put
                # into the batch here.
                batch = Batch(
                    obs=np.array([obs]),
                    act=np.array([act]),
                    rew=np.array([rew]),
                    terminated=np.array([terminated]),
                    truncated=np.array([truncated]),
                    obs_next=np.array([obs_next]),
                    info=np.array([info]),
                )
                buf.add(batch)
                obs = obs_next
                if done:
                    obs, info = env.reset()
        finally:
            # A new environment is created every round; close it so its
            # resources are released even if a step or add call raises.
            env.close()
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: time both buffer tests on the task named by the
    # last command-line argument.
    #
    # With no argument, ``sys.argv[-1]`` is the script path itself, which is
    # not an environment id. In that case pass nothing so each test function
    # falls back to its own default task.
    task_args = sys.argv[-1:] if len(sys.argv) > 1 else []

    t0 = time.time()
    test_replaybuffer(*task_args)
    print("test replaybuffer: ", time.time() - t0)

    t0 = time.time()
    test_vectorbuffer(*task_args)
    print("test vectorbuffer: ", time.time() - t0)
|