- Refactor code to remove duplicate code - Enable async simulation for all vector envs - Remove `collector.close` and rename `VectorEnv` to `DummyVectorEnv` The abstraction of vector env changed. Prior to this PR, each vector env was almost independent. After this PR, each env is wrapped into a worker, and vector envs differ only in their worker type. In fact, users can just use `BaseVectorEnv` with different workers; I keep `SubprocVectorEnv` and `ShmemVectorEnv` for backward compatibility. Co-authored-by: n+e <463003665@qq.com> Co-authored-by: magicly <magicly007@gmail.com>
122 lines
3.3 KiB
Python
122 lines
3.3 KiB
Python
import copy
|
|
import pickle
|
|
|
|
import numpy as np
|
|
import pytest
|
|
import torch
|
|
|
|
from tianshou.data import Batch
|
|
|
|
|
|
@pytest.fixture(scope="module")
def data():
    """Build the shared Batch fixtures used by every profiling test.

    Module-scoped so the (expensive) construction cost is paid once
    per test session rather than once per test function.

    Returns a dict with:
        batch_set:  1e4 small heterogeneous Batches (construction input).
        batch0:     a small nested Batch template.
        batchs1/2:  independent deep copies of batch0 — cat_/stack_
                    mutate their target in place, so each test needs
                    its own copies.
        batch3:     a large flat Batch for item/attr get/set tests.
        indexs:     random row indices into batch3 (key name kept for
                    backward compatibility with the test functions).
        slice_dict: dict shaped like batch3[indexs] for sliced writes.
        dict_set:   plain-dict payloads for attribute assignment.
        batch4:     a large numeric Batch for conversion/pickle tests.
    """
    print("Initialising data...")
    np.random.seed(0)  # deterministic index sampling across runs
    # `a` stays a Python list of scalars on purpose — it exercises
    # Batch's list-input path (was a redundant identity comprehension).
    batch_set = [Batch(a=list(np.arange(1e3)),
                       b={'b1': (3.14, 3.14), 'b2': np.arange(1e3)},
                       c=i) for i in np.arange(int(1e4))]
    batch0 = Batch(
        a=np.ones((3, 4), dtype=np.float64),
        b=Batch(
            c=np.ones((1,), dtype=np.float64),
            d=torch.ones((3, 3, 3), dtype=torch.float32),
            e=list(range(3)),
        ),
    )
    # range() instead of np.arange(1e4): no point materialising a float
    # array just to drive a counter.
    batchs1 = [copy.deepcopy(batch0) for _ in range(int(1e4))]
    batchs2 = [copy.deepcopy(batch0) for _ in range(int(1e4))]
    batch_len = int(1e4)
    batch3 = Batch(obs=[np.arange(20) for _ in range(batch_len)],
                   reward=np.arange(batch_len))
    indexs = np.random.choice(batch_len,
                              size=batch_len // 10, replace=False)
    slice_dict = {'obs': [np.arange(20)
                          for _ in range(batch_len // 10)],
                  'reward': np.arange(batch_len // 10)}
    dict_set = [{'obs': np.arange(20), 'info': "this is info", 'reward': 0}
                for _ in range(int(1e2))]
    batch4 = Batch(
        a=np.ones((10000, 4), dtype=np.float64),
        b=Batch(
            c=np.ones((1,), dtype=np.float64),
            d=torch.ones((1000, 1000), dtype=torch.float32),
            e=np.arange(1000),
        ),
    )
    print("Initialised")
    return {
        'batch_set': batch_set,
        'batch0': batch0,
        'batchs1': batchs1,
        'batchs2': batchs2,
        'batch3': batch3,
        'indexs': indexs,
        'dict_set': dict_set,
        'slice_dict': slice_dict,
        'batch4': batch4,
    }
|
|
|
|
|
|
def test_init(data):
    """Benchmark constructing a Batch from a list of Batch objects."""
    repeats = 10
    for _ in range(repeats):
        Batch(data['batch_set'])
|
|
|
|
|
|
def test_get_item(data):
|
|
"""Test get with item."""
|
|
for _ in np.arange(1e5):
|
|
_ = data['batch3'][data['indexs']]
|
|
|
|
|
|
def test_get_attr(data):
    """Benchmark reads via .get() and via attribute access on a Batch."""
    batch = data['batch3']
    for _ in range(1000000):
        batch.get('obs')
        batch.get('reward')
        obs, reward = batch.obs, batch.reward
|
|
|
|
|
|
def test_set_item(data):
|
|
"""Test set with item."""
|
|
for _ in np.arange(1e4):
|
|
data['batch3'][data['indexs']] = data['slice_dict']
|
|
|
|
|
|
def test_set_attr(data):
|
|
"""Test set with attr."""
|
|
for _ in np.arange(1e4):
|
|
data['batch3'].c = np.arange(1e3)
|
|
data['batch3'].obs = data['dict_set']
|
|
|
|
|
|
def test_numpy_torch_convert(data):
|
|
"""Test conversion between numpy and torch."""
|
|
for _ in np.arange(1e5):
|
|
data['batch4'].to_torch()
|
|
data['batch4'].to_numpy()
|
|
|
|
|
|
def test_pickle(data):
|
|
for _ in np.arange(1e4):
|
|
pickle.loads(pickle.dumps(data['batch4']))
|
|
|
|
|
|
def test_cat(data):
    """Benchmark Batch.cat (out-of-place) and cat_ (in-place)."""
    template = data['batch0']
    for idx in range(10000):
        Batch.cat((template, template))
        data['batchs1'][idx].cat_(template)
|
|
|
|
|
|
def test_stack(data):
    """Benchmark Batch.stack (out-of-place) and stack_ (in-place)."""
    template = data['batch0']
    for idx in range(10000):
        Batch.stack((template, template))
        data['batchs2'][idx].stack_([template])
|
|
|
|
|
|
if __name__ == '__main__':
    # Run only the batch-profiling tests, verbosely, with a timing report.
    pytest_args = ["-s", "-k batch_profile", "--durations=0", "-v"]
    pytest.main(pytest_args)
|