Tianshou/test/throughput/test_buffer_profile.py
ChenDRAG 150d0ec51b
Step collector implementation (#280)
This is the third PR of 6 commits mentioned in #274, which features refactor of Collector to fix #245. You can check #274 for more detail.

Things changed in this PR:

1. refactor collector to be cleaner, split out AsyncCollector to support async venv;
2. change buffer.add API to add(batch, buffer_ids); add several types of buffer (VectorReplayBuffer, PrioritizedVectorReplayBuffer, etc.)
3. add policy.exploration_noise(act, batch) -> act
4. small change in BasePolicy.compute_*_returns
5. move reward_metric from collector to trainer
6. fix np.asanyarray issue (different version's numpy will result in different output)
7. flake8 maxlength=88
8. polish docs and fix test

Co-authored-by: n+e <trinkle23897@gmail.com>
2021-02-19 10:33:49 +08:00

62 lines
1.8 KiB
Python

import sys
import gym
import time
import tqdm
import numpy as np
from tianshou.data import Batch, ReplayBuffer, VectorReplayBuffer
def test_replaybuffer(task="Pendulum-v0"):
    """Exercise ``ReplayBuffer.add`` with transitions from random rollouts.

    Runs 5 rounds. Each round builds a fresh environment and a fresh
    ``ReplayBuffer(10000)``, then takes 100000 randomly sampled actions,
    adding every transition to the buffer as a single-element ``Batch``.
    The function returns nothing; it exists so callers can time it.

    :param task: environment id handed to ``gym.make``.
    """
    total_count = 5
    for _ in tqdm.trange(total_count, desc="ReplayBuffer"):
        env = gym.make(task)
        try:
            buf = ReplayBuffer(10000)
            obs = env.reset()
            for _ in range(100000):
                act = env.action_space.sample()
                obs_next, rew, done, info = env.step(act)
                # Every field is wrapped in a length-1 array so the batch
                # carries exactly one transition.
                batch = Batch(
                    obs=np.array([obs]),
                    act=np.array([act]),
                    rew=np.array([rew]),
                    done=np.array([done]),
                    obs_next=np.array([obs_next]),
                    info=np.array([info]),
                )
                buf.add(batch, buffer_ids=[0])
                obs = obs_next
                if done:
                    obs = env.reset()
        finally:
            # A new env is created every round; release it so repeated
            # rounds do not accumulate open environments.
            env.close()
def test_vectorbuffer(task="Pendulum-v0"):
    """Exercise ``VectorReplayBuffer.add`` with transitions from random rollouts.

    Runs 5 rounds. Each round builds a fresh environment and a fresh
    ``VectorReplayBuffer(total_size=10000, buffer_num=1)``, then takes
    100000 randomly sampled actions, adding every transition to the buffer
    as a single-element ``Batch``. The function returns nothing; it exists
    so callers can time it.

    :param task: environment id handed to ``gym.make``.
    """
    total_count = 5
    for _ in tqdm.trange(total_count, desc="VectorReplayBuffer"):
        env = gym.make(task)
        try:
            buf = VectorReplayBuffer(total_size=10000, buffer_num=1)
            obs = env.reset()
            for _ in range(100000):
                act = env.action_space.sample()
                obs_next, rew, done, info = env.step(act)
                # Every field is wrapped in a length-1 array so the batch
                # carries exactly one transition.
                batch = Batch(
                    obs=np.array([obs]),
                    act=np.array([act]),
                    rew=np.array([rew]),
                    done=np.array([done]),
                    obs_next=np.array([obs_next]),
                    info=np.array([info]),
                )
                # No buffer_ids are passed here, unlike the ReplayBuffer case.
                buf.add(batch)
                obs = obs_next
                if done:
                    obs = env.reset()
        finally:
            # A new env is created every round; release it so repeated
            # rounds do not accumulate open environments.
            env.close()
if __name__ == '__main__':
    # Script entry point: time both buffer benchmarks on the same task and
    # print the elapsed seconds for each.
    #
    # The task id is taken from the last command-line argument. When the
    # script is run without arguments, sys.argv[-1] is the script path
    # itself, which is not a valid environment id, so fall back to the
    # default task used by the benchmark functions.
    task = sys.argv[-1] if len(sys.argv) > 1 else "Pendulum-v0"

    # perf_counter is monotonic, so the measured interval cannot be skewed
    # by system clock adjustments during the run.
    t0 = time.perf_counter()
    test_replaybuffer(task)
    print("test replaybuffer: ", time.perf_counter() - t0)

    t0 = time.perf_counter()
    test_vectorbuffer(task)
    print("test vectorbuffer: ", time.perf_counter() - t0)