"""Throughput profiling for :class:`tianshou.data.Batch`.

Each test repeats a single ``Batch`` operation many times so that pytest's
``--durations`` report shows where time is spent. All tests share one
module-scoped fixture, and several of them mutate its contents in place
(``test_set_item``, ``test_set_attr``, ``test_numpy_torch_convert``,
``test_cat``, ``test_stack``), so they are intended to run in file order.
"""
import copy
import pickle

import numpy as np
import pytest
import torch

from tianshou.data import Batch


@pytest.fixture(scope="module")
def data():
    """Build the (expensive) shared inputs once per module.

    Returns:
        A dict of batches, index arrays and plain dicts consumed by the
        profiling tests in this module. Tests may mutate these objects.
    """
    print("Initialising data...")
    np.random.seed(0)
    batch_len = int(1e4)
    batch_set = [
        Batch(
            a=list(np.arange(1e3)),
            b={
                'b1': (3.14, 3.14),
                'b2': np.arange(1e3),
            },
            c=i,
        )
        for i in np.arange(batch_len)
    ]
    batch0 = Batch(
        a=np.ones((3, 4), dtype=np.float64),
        b=Batch(
            c=np.ones((1, ), dtype=np.float64),
            d=torch.ones((3, 3, 3), dtype=torch.float32),
            e=list(range(3)),
        ),
    )
    # Independent deep copies, so the in-place cat_/stack_ calls in
    # test_cat/test_stack operate on copies rather than on batch0.
    batchs1 = [copy.deepcopy(batch0) for _ in range(batch_len)]
    batchs2 = [copy.deepcopy(batch0) for _ in range(batch_len)]
    batch3 = Batch(
        obs=[np.arange(20) for _ in range(batch_len)],
        reward=np.arange(batch_len),
    )
    # 10% of batch3's rows, sampled without replacement; slice_dict holds
    # replacement values of matching length for those rows.
    indexs = np.random.choice(batch_len, size=batch_len // 10, replace=False)
    slice_dict = {
        'obs': [np.arange(20) for _ in range(batch_len // 10)],
        'reward': np.arange(batch_len // 10),
    }
    dict_set = [
        {
            'obs': np.arange(20),
            'info': "this is info",
            'reward': 0,
        }
        for _ in range(100)
    ]
    batch4 = Batch(
        a=np.ones((10000, 4), dtype=np.float64),
        b=Batch(
            c=np.ones((1, ), dtype=np.float64),
            d=torch.ones((1000, 1000), dtype=torch.float32),
            e=np.arange(1000),
        ),
    )
    print("Initialised")
    return {
        'batch_set': batch_set,
        'batch0': batch0,
        'batchs1': batchs1,
        'batchs2': batchs2,
        'batch3': batch3,
        'indexs': indexs,
        'dict_set': dict_set,
        'slice_dict': slice_dict,
        'batch4': batch4,
    }


def test_init(data):
    """Test Batch __init__()."""
    batch_set = data['batch_set']
    for _ in range(10):
        Batch(batch_set)


def test_get_item(data):
    """Test get with item."""
    batch3, indexs = data['batch3'], data['indexs']
    for _ in range(100_000):
        _selected = batch3[indexs]


def test_get_attr(data):
    """Test get with attr."""
    batch3 = data['batch3']
    for _ in range(1_000_000):
        batch3.get('obs')
        batch3.get('reward')
        _obs, _reward = batch3.obs, batch3.reward


def test_set_item(data):
    """Test set with item."""
    # NOTE: writes into the shared, module-scoped batch3.
    batch3, indexs, slice_dict = (
        data['batch3'], data['indexs'], data['slice_dict'])
    for _ in range(10_000):
        batch3[indexs] = slice_dict


def test_set_attr(data):
    """Test set with attr."""
    # NOTE: replaces batch3.obs with a 100-element list, so tests that index
    # batch3 with `indexs` must run before this one (file order guarantees it).
    batch3, dict_set = data['batch3'], data['dict_set']
    for _ in range(10_000):
        batch3.c = np.arange(1e3)
        batch3.obs = dict_set


def test_numpy_torch_convert(data):
    """Test conversion between numpy and torch."""
    batch4 = data['batch4']
    for _ in range(100_000):
        batch4.to_torch()
        batch4.to_numpy()


def test_pickle(data):
    """Test a pickle dump/load round trip of a large Batch."""
    batch4 = data['batch4']
    for _ in range(10_000):
        pickle.loads(pickle.dumps(batch4))


def test_cat(data):
    """Test cat"""
    batch0 = data['batch0']
    # batchs1 holds 10000 independent copies; each is concatenated once.
    for target in data['batchs1']:
        Batch.cat((batch0, batch0))
        target.cat_(batch0)


def test_stack(data):
    """Test stack"""
    batch0 = data['batch0']
    # batchs2 holds 10000 independent copies; each is stacked onto once.
    for target in data['batchs2']:
        Batch.stack((batch0, batch0))
        target.stack_([batch0])


if __name__ == '__main__':
    # Run exactly this file rather than relying on a `-k` keyword match
    # against whatever the current working directory happens to contain.
    pytest.main(["-s", __file__, "--durations=0", "-v"])