| 
									
										
										
										
											2020-08-02 18:24:40 +08:00
										 |  |  | import copy | 
					
						
							|  |  |  | import pickle | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import numpy as np | 
					
						
							|  |  |  | import pytest | 
					
						
							|  |  |  | import torch | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from tianshou.data import Batch | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							@pytest.fixture(scope="module")
							def data():
							    """Build the inputs shared by every profiling test in this module.

							    The fixture is module-scoped, so the objects below are created once and
							    the same instances are handed to each test that requests ``data``.

							    Returns:
							        A dict with the keys ``batch_set``, ``batch0``, ``batchs1``,
							        ``batchs2``, ``batch3``, ``indexs``, ``dict_set``, ``slice_dict``
							        and ``batch4``.
							    """
							    print("Initialising data...")
							    np.random.seed(0)
							    batch_len = int(1e4)

							    # A long list of Batch objects, each mixing a list, a nested dict and
							    # a scalar taken from the numpy range being iterated.
							    batch_set = []
							    for idx in np.arange(batch_len):
							        nested = {'b1': (3.14, 3.14), 'b2': np.arange(1e3)}
							        batch_set.append(Batch(a=list(np.arange(1e3)), b=nested, c=idx))

							    # A small template Batch and two independent pools of deep copies.
							    inner0 = Batch(
							        c=np.ones((1, ), dtype=np.float64),
							        d=torch.ones((3, 3, 3), dtype=torch.float32),
							        e=list(range(3)),
							    )
							    batch0 = Batch(a=np.ones((3, 4), dtype=np.float64), b=inner0)
							    batchs1 = [copy.deepcopy(batch0) for _ in range(batch_len)]
							    batchs2 = [copy.deepcopy(batch0) for _ in range(batch_len)]

							    # A flat Batch with batch_len entries, a random tenth of its indices
							    # (drawn without replacement), and a replacement payload of that size.
							    obs_rows = [np.arange(20) for _ in range(batch_len)]
							    batch3 = Batch(obs=obs_rows, reward=np.arange(batch_len))
							    subset_len = batch_len // 10
							    indexs = np.random.choice(batch_len, size=subset_len, replace=False)
							    slice_dict = {
							        'obs': [np.arange(20) for _ in range(subset_len)],
							        'reward': np.arange(subset_len),
							    }

							    # One hundred plain dicts, used as the value for attribute assignment.
							    dict_set = [
							        {'obs': np.arange(20), 'info': "this is info", 'reward': 0}
							        for _ in range(100)
							    ]

							    # A large Batch holding numpy arrays and a torch tensor.
							    inner4 = Batch(
							        c=np.ones((1, ), dtype=np.float64),
							        d=torch.ones((1000, 1000), dtype=torch.float32),
							        e=np.arange(1000),
							    )
							    batch4 = Batch(a=np.ones((10000, 4), dtype=np.float64), b=inner4)

							    print("Initialised")
							    return dict(
							        batch_set=batch_set,
							        batch0=batch0,
							        batchs1=batchs1,
							        batchs2=batchs2,
							        batch3=batch3,
							        indexs=indexs,
							        dict_set=dict_set,
							        slice_dict=slice_dict,
							        batch4=batch4,
							    )
					
						
							| 
									
										
										
										
											2020-08-02 18:24:40 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							def test_init(data):
							    """Profile Batch construction from the ``batch_set`` list, ten times over."""
							    samples = data['batch_set']
							    for _ in range(10):
							        # The result is discarded; only the constructor call matters here.
							        Batch(samples)
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							def test_get_item(data):
							    """Profile indexing ``batch3`` with the ``indexs`` array, 1e5 times."""
							    rounds = int(1e5)
							    for _ in range(rounds):
							        # Look both objects up from the fixture on every pass, then index.
							        source, picks = data['batch3'], data['indexs']
							        source[picks]
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							def test_get_attr(data):
							    """Profile reading ``obs`` and ``reward`` from ``batch3``, 1e6 times.

							    Each pass reads both fields twice: once through ``get()`` and once
							    through plain attribute access.
							    """
							    rounds = int(1e6)
							    for _ in range(rounds):
							        # Access via the get() method.
							        data['batch3'].get('obs')
							        data['batch3'].get('reward')
							        # Access via attributes, in the same order.
							        data['batch3'].obs
							        data['batch3'].reward
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							def test_set_item(data):
							    """Profile assigning ``slice_dict`` into ``batch3`` at ``indexs``, 1e4 times."""
							    rounds = int(1e4)
							    for _ in range(rounds):
							        target = data['batch3']
							        picks = data['indexs']
							        # Item assignment writes into the fixture's shared batch3.
							        target[picks] = data['slice_dict']
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							def test_set_attr(data):
							    """Profile attribute assignment on ``batch3``, 1e4 times.

							    Each pass sets ``c`` to a freshly built numpy range and ``obs`` to the
							    fixture's ``dict_set`` list.
							    """
							    rounds = int(1e4)
							    for _ in range(rounds):
							        # Build a new array on every pass, as part of the measured work.
							        fresh = np.arange(1e3)
							        data['batch3'].c = fresh
							        data['batch3'].obs = data['dict_set']
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							def test_numpy_torch_convert(data):
							    """Profile ``to_torch()`` followed by ``to_numpy()`` on ``batch4``, 1e5 times."""
							    rounds = int(1e5)
							    for _ in range(rounds):
							        # Call the two conversions back to back on the same shared batch.
							        for convert in ('to_torch', 'to_numpy'):
							            getattr(data['batch4'], convert)()
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							def test_pickle(data):
							    """Profile a pickle dumps/loads round trip of ``batch4``, 1e4 times."""
							    rounds = int(1e4)
							    for _ in range(rounds):
							        # Serialise the fixture's own object, then immediately restore it.
							        payload = pickle.dumps(data['batch4'])
							        pickle.loads(payload)
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							def test_cat(data):
							    """Profile ``Batch.cat`` and in-place ``cat_`` using ``batch0``.

							    Runs 10000 passes; pass ``i`` also calls ``cat_`` on ``batchs1[i]``, so
							    each pooled copy is extended exactly once.
							    """
							    for idx in range(10000):
							        template = data['batch0']
							        # Call on the class with a pair of the same template.
							        Batch.cat((template, template))
							        # In-place variant, applied to this pass's pooled copy.
							        pooled = data['batchs1'][idx]
							        pooled.cat_(data['batch0'])
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							def test_stack(data):
							    """Profile ``Batch.stack`` and in-place ``stack_`` using ``batch0``.

							    Runs 10000 passes; pass ``i`` also calls ``stack_`` on ``batchs2[i]``
							    with a one-element list holding ``batch0``.
							    """
							    for idx in range(10000):
							        template = data['batch0']
							        # Call on the class with a pair of the same template.
							        Batch.stack((template, template))
							        # In-place variant, applied to this pass's pooled copy.
							        pooled = data['batchs2'][idx]
							        pooled.stack_([data['batch0']])
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							if __name__ == '__main__':
							    # When run as a script, hand these options to pytest: -s, a -k
							    # keyword filter on "batch_profile", --durations=0 and -v.
							    pytest_args = ["-s", "-k batch_profile", "--durations=0", "-v"]
							    pytest.main(pytest_args)