Tianshou/test/throughput/test_buffer_profile.py
Markus Krimmel ea36dc5195
Changes to support Gym 0.26.0 (#748)
* Changes to support Gym 0.26.0

* Replace map by simpler list comprehension

* Use syntax that is compatible with python 3.7

* Format code

* Fix environment seeding in test environment, fix buffer_profile test

* Remove self.seed() from __init__

* Fix random number generation

* Fix throughput tests

* Fix tests

* Removed done field from Buffer, fixed throughput test, turned off wandb, fixed formatting, fixed type hints, allow preprocessing_fn with truncated and terminated arguments, updated docstrings

* fix lint

* fix

* fix import

* fix

* fix mypy

* pytest --ignore='test/3rd_party'

* Use correct step API in _SetAttrWrapper

* Format

* Fix mypy

* Format

* Fix pydocstyle.
2022-09-26 09:31:23 -07:00

69 lines
2.1 KiB
Python

import sys
import time
import gym
import numpy as np
import tqdm
from tianshou.data import Batch, ReplayBuffer, VectorReplayBuffer
def test_replaybuffer(task="Pendulum-v1", total_count=5, step_count=100000):
    """Profile adding transitions one at a time to a ``ReplayBuffer``.

    For each repetition a fresh environment and a fresh ``ReplayBuffer(10000)``
    are created, random actions are sampled from the action space, and every
    resulting transition is pushed into the buffer. The environment is reset
    whenever an episode terminates or is truncated.

    :param task: environment id passed to ``gym.make``.
    :param total_count: number of independent repetitions (progress-bar steps).
    :param step_count: number of environment steps taken per repetition.
    """
    for _ in tqdm.trange(total_count, desc="ReplayBuffer"):
        env = gym.make(task)
        # Release the environment even if stepping or buffer insertion raises;
        # otherwise every repetition leaves an unclosed env behind.
        try:
            buf = ReplayBuffer(10000)
            obs, info = env.reset()
            for _ in range(step_count):
                act = env.action_space.sample()
                obs_next, rew, terminated, truncated, info = env.step(act)
                done = terminated or truncated
                # Every field is wrapped in a length-1 leading axis, so the
                # batch holds exactly one transition.
                batch = Batch(
                    obs=np.array([obs]),
                    act=np.array([act]),
                    rew=np.array([rew]),
                    terminated=np.array([terminated]),
                    truncated=np.array([truncated]),
                    done=np.array([done]),
                    obs_next=np.array([obs_next]),
                    info=np.array([info]),
                )
                # NOTE(review): buffer_ids=[0] presumably marks the batch as
                # stacked for the single underlying buffer -- confirm against
                # ReplayBuffer.add.
                buf.add(batch, buffer_ids=[0])
                obs = obs_next
                if done:
                    obs, info = env.reset()
        finally:
            env.close()
def test_vectorbuffer(task="Pendulum-v1", total_count=5, step_count=100000):
    """Profile adding transitions one at a time to a ``VectorReplayBuffer``.

    For each repetition a fresh environment and a fresh
    ``VectorReplayBuffer(total_size=10000, buffer_num=1)`` are created, random
    actions are sampled from the action space, and every resulting transition
    is pushed into the buffer. The environment is reset whenever an episode
    terminates or is truncated.

    :param task: environment id passed to ``gym.make``.
    :param total_count: number of independent repetitions (progress-bar steps).
    :param step_count: number of environment steps taken per repetition.
    """
    for _ in tqdm.trange(total_count, desc="VectorReplayBuffer"):
        env = gym.make(task)
        # Release the environment even if stepping or buffer insertion raises;
        # otherwise every repetition leaves an unclosed env behind.
        try:
            buf = VectorReplayBuffer(total_size=10000, buffer_num=1)
            obs, info = env.reset()
            for _ in range(step_count):
                act = env.action_space.sample()
                obs_next, rew, terminated, truncated, info = env.step(act)
                # Only used locally to decide when to reset; it is not stored
                # in the batch below.
                done = terminated or truncated
                # Every field is wrapped in a length-1 leading axis, matching
                # the single sub-buffer (buffer_num=1).
                batch = Batch(
                    obs=np.array([obs]),
                    act=np.array([act]),
                    rew=np.array([rew]),
                    terminated=np.array([terminated]),
                    truncated=np.array([truncated]),
                    obs_next=np.array([obs_next]),
                    info=np.array([info]),
                )
                buf.add(batch)
                obs = obs_next
                if done:
                    obs, info = env.reset()
        finally:
            env.close()
if __name__ == '__main__':
    # Usage: python test_buffer_profile.py [TASK]
    #
    # sys.argv[0] is the script path, so indexing sys.argv[-1] blindly would
    # hand the file name to gym.make when no task is given. Forward the last
    # argument only if one was actually supplied; otherwise let each test fall
    # back to its own default task.
    cli_args = sys.argv[1:]
    task_args = (cli_args[-1],) if cli_args else ()

    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # wall-clock adjustments during the run.
    t0 = time.perf_counter()
    test_replaybuffer(*task_args)
    print("test replaybuffer: ", time.perf_counter() - t0)
    t0 = time.perf_counter()
    test_vectorbuffer(*task_args)
    print("test vectorbuffer: ", time.perf_counter() - t0)