| 
									
										
										
										
											2020-06-23 16:50:59 +02:00
										 |  |  | import copy | 
					
						
							| 
									
										
										
										
											2020-06-01 08:30:09 +08:00
										 |  |  | import pickle | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  | import sys | 
					
						
							| 
									
										
										
										
											2020-07-21 10:47:56 +02:00
										 |  |  | from itertools import starmap | 
					
						
							| 
									
										
										
										
											2020-03-13 17:49:22 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  | import networkx as nx | 
					
						
							|  |  |  | import numpy as np | 
					
						
							|  |  |  | import pytest | 
					
						
							|  |  |  | import torch | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from tianshou.data import Batch, to_numpy, to_torch | 
					
						
							| 
									
										
										
										
											2020-03-13 17:49:22 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def test_batch(): | 
					
						
							| 
									
										
										
										
											2020-07-08 21:00:00 +08:00
										 |  |  |     assert list(Batch()) == [] | 
					
						
							| 
									
										
										
										
											2020-07-11 09:44:47 +08:00
										 |  |  |     assert Batch().is_empty() | 
					
						
							| 
									
										
										
										
											2020-07-16 19:36:32 +08:00
										 |  |  |     assert not Batch(b={'c': {}}).is_empty() | 
					
						
							|  |  |  |     assert Batch(b={'c': {}}).is_empty(recurse=True) | 
					
						
							|  |  |  |     assert not Batch(a=Batch(), b=Batch(c=Batch())).is_empty() | 
					
						
							|  |  |  |     assert Batch(a=Batch(), b=Batch(c=Batch())).is_empty(recurse=True) | 
					
						
							|  |  |  |     assert not Batch(d=1).is_empty() | 
					
						
							|  |  |  |     assert not Batch(a=np.float64(1.0)).is_empty() | 
					
						
							| 
									
										
										
										
											2020-07-11 21:46:01 +08:00
										 |  |  |     assert len(Batch(a=[1, 2, 3], b={'c': {}})) == 3 | 
					
						
							| 
									
										
										
										
											2020-07-11 09:44:47 +08:00
										 |  |  |     assert not Batch(a=[1, 2, 3]).is_empty() | 
					
						
							| 
									
										
										
											
												fix 2 bugs of batch (#284)
1. `_create_value(Batch(a={}, b=[1, 2, 3]), 10, False)`
before:
```python
TypeError: cannot concatenate with Batch() which is scalar
```
after:
```python
Batch(
    a: Batch(),
    b: array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
)
```
2. creating keys in a batch's subkey, e.g. 
```python
a = Batch(info={"key1": [0, 1], "key2": [2, 3]})
a[0] = Batch(info={"key1": 2, "key3": 4})
print(a)
```
before:
```python
Batch(
    info: Batch(
              key1: array([0, 1]),
              key2: array([0, 3]),
          ),
)
```
after:
```python
ValueError: Creating keys is not supported by item assignment.
```
3. small optimization for `Batch.stack_` and `Batch.cat_`, raise ValueError when receiving invalid data format.
											
										 
											2021-02-02 19:28:05 +08:00
										 |  |  |     b = Batch({'a': [4, 4], 'b': [5, 5]}, c=[None, None]) | 
					
						
							| 
									
										
										
										
											2021-03-30 16:06:03 +08:00
										 |  |  |     assert b.c.dtype == object | 
					
						
							| 
									
										
										
										
											2021-02-19 10:33:49 +08:00
										 |  |  |     b = Batch(d=[None], e=[starmap], f=Batch) | 
					
						
							| 
									
										
										
										
											2021-03-30 16:06:03 +08:00
										 |  |  |     assert b.d.dtype == b.e.dtype == object and b.f == Batch | 
					
						
							| 
									
										
										
										
											2020-07-11 21:46:01 +08:00
										 |  |  |     b = Batch() | 
					
						
							|  |  |  |     b.update() | 
					
						
							|  |  |  |     assert b.is_empty() | 
					
						
							|  |  |  |     b.update(c=[3, 5]) | 
					
						
							|  |  |  |     assert np.allclose(b.c, [3, 5]) | 
					
						
							|  |  |  |     # mimic the behavior of dict.update, where kwargs can overwrite keys | 
					
						
							|  |  |  |     b.update({'a': 2}, a=3) | 
					
						
							| 
									
										
										
										
											2020-08-27 12:15:18 +08:00
										 |  |  |     assert 'a' in b and b.a == 3 | 
					
						
							|  |  |  |     assert b.pop('a') == 3 | 
					
						
							|  |  |  |     assert 'a' not in b | 
					
						
							| 
									
										
										
										
											2020-07-11 09:44:47 +08:00
										 |  |  |     with pytest.raises(AssertionError): | 
					
						
							|  |  |  |         Batch({1: 2}) | 
					
						
							| 
									
										
										
										
											2021-04-25 15:23:46 +08:00
										 |  |  |     assert Batch(a=[np.zeros((2, 3)), np.zeros((3, 3))]).a.dtype == object | 
					
						
							| 
									
										
										
										
											2020-07-21 10:47:56 +02:00
										 |  |  |     with pytest.raises(TypeError): | 
					
						
							|  |  |  |         Batch(a=[np.zeros((3, 2)), np.zeros((3, 3))]) | 
					
						
							|  |  |  |     with pytest.raises(TypeError): | 
					
						
							|  |  |  |         Batch(a=[torch.zeros((2, 3)), torch.zeros((3, 3))]) | 
					
						
							|  |  |  |     with pytest.raises(TypeError): | 
					
						
							|  |  |  |         Batch(a=[torch.zeros((3, 3)), np.zeros((3, 3))]) | 
					
						
							|  |  |  |     with pytest.raises(TypeError): | 
					
						
							|  |  |  |         Batch(a=[1, np.zeros((3, 3)), torch.zeros((3, 3))]) | 
					
						
							| 
									
										
										
										
											2020-07-08 21:00:00 +08:00
										 |  |  |     batch = Batch(a=[torch.ones(3), torch.ones(3)]) | 
					
						
							|  |  |  |     assert torch.allclose(batch.a, torch.ones(2, 3)) | 
					
						
							| 
									
										
										
										
											2020-08-27 12:15:18 +08:00
										 |  |  |     batch.cat_(batch) | 
					
						
							|  |  |  |     assert torch.allclose(batch.a, torch.ones(4, 3)) | 
					
						
							| 
									
										
										
										
											2020-07-21 10:47:56 +02:00
										 |  |  |     Batch(a=[]) | 
					
						
							| 
									
										
										
										
											2020-03-13 17:49:22 +08:00
										 |  |  |     batch = Batch(obs=[0], np=np.zeros([3, 4])) | 
					
						
							| 
									
										
										
										
											2020-05-30 15:40:31 +02:00
										 |  |  |     assert batch.obs == batch["obs"] | 
					
						
							| 
									
										
										
										
											2020-04-03 21:28:12 +08:00
										 |  |  |     batch.obs = [1] | 
					
						
							| 
									
										
										
										
											2020-03-13 17:49:22 +08:00
										 |  |  |     assert batch.obs == [1] | 
					
						
							| 
									
										
										
										
											2020-06-23 16:50:59 +02:00
										 |  |  |     batch.cat_(batch) | 
					
						
							| 
									
										
										
										
											2020-06-30 18:02:44 +08:00
										 |  |  |     assert np.allclose(batch.obs, [1, 1]) | 
					
						
							| 
									
										
										
										
											2020-03-13 17:49:22 +08:00
										 |  |  |     assert batch.np.shape == (6, 4) | 
					
						
							| 
									
										
										
										
											2020-06-30 18:02:44 +08:00
										 |  |  |     assert np.allclose(batch[0].obs, batch[1].obs) | 
					
						
							| 
									
										
										
										
											2020-03-17 11:37:31 +08:00
										 |  |  |     batch.obs = np.arange(5) | 
					
						
							| 
									
										
										
										
											2020-04-28 20:56:02 +08:00
										 |  |  |     for i, b in enumerate(batch.split(1, shuffle=False)): | 
					
						
							| 
									
										
										
										
											2020-06-01 08:30:09 +08:00
										 |  |  |         if i != 5: | 
					
						
							|  |  |  |             assert b.obs == batch[i].obs | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             with pytest.raises(AttributeError): | 
					
						
							|  |  |  |                 batch[i].obs | 
					
						
							|  |  |  |             with pytest.raises(AttributeError): | 
					
						
							|  |  |  |                 b.obs | 
					
						
							| 
									
										
										
										
											2020-04-09 19:53:45 +08:00
										 |  |  |     print(batch) | 
					
						
							| 
									
										
										
										
											2020-08-27 12:15:18 +08:00
										 |  |  |     batch = Batch(a=np.arange(10)) | 
					
						
							|  |  |  |     with pytest.raises(AssertionError): | 
					
						
							|  |  |  |         list(batch.split(0)) | 
					
						
							|  |  |  |     data = [ | 
					
						
							|  |  |  |         (1, False, [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]), | 
					
						
							|  |  |  |         (1, True, [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]), | 
					
						
							|  |  |  |         (3, False, [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]), | 
					
						
							|  |  |  |         (3, True, [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]), | 
					
						
							|  |  |  |         (5, False, [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]), | 
					
						
							|  |  |  |         (5, True, [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]), | 
					
						
							|  |  |  |         (7, False, [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]), | 
					
						
							|  |  |  |         (7, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), | 
					
						
							|  |  |  |         (10, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), | 
					
						
							|  |  |  |         (10, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), | 
					
						
							|  |  |  |         (15, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), | 
					
						
							|  |  |  |         (15, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), | 
					
						
							|  |  |  |         (100, False, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), | 
					
						
							|  |  |  |         (100, True, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]), | 
					
						
							|  |  |  |     ] | 
					
						
							|  |  |  |     for size, merge_last, result in data: | 
					
						
							|  |  |  |         bs = list(batch.split(size, shuffle=False, merge_last=merge_last)) | 
					
						
							|  |  |  |         assert [bs[i].a.tolist() for i in range(len(bs))] == result | 
					
						
							| 
									
										
										
										
											2020-06-23 16:50:59 +02:00
										 |  |  |     batch_dict = {'b': np.array([1.0]), 'c': 2.0, 'd': torch.Tensor([3.0])} | 
					
						
							|  |  |  |     batch_item = Batch({'a': [batch_dict]})[0] | 
					
						
							|  |  |  |     assert isinstance(batch_item.a.b, np.ndarray) | 
					
						
							|  |  |  |     assert batch_item.a.b == batch_dict['b'] | 
					
						
							|  |  |  |     assert isinstance(batch_item.a.c, float) | 
					
						
							|  |  |  |     assert batch_item.a.c == batch_dict['c'] | 
					
						
							|  |  |  |     assert isinstance(batch_item.a.d, torch.Tensor) | 
					
						
							|  |  |  |     assert batch_item.a.d == batch_dict['d'] | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     batch2 = Batch( | 
					
						
							|  |  |  |         a=[{ | 
					
						
							|  |  |  |             'b': np.float64(1.0), | 
					
						
							|  |  |  |             'c': np.zeros(1), | 
					
						
							|  |  |  |             'd': Batch(e=np.array(3.0)) | 
					
						
							|  |  |  |         }] | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2020-06-24 00:55:24 +02:00
										 |  |  |     assert len(batch2) == 1 | 
					
						
							| 
									
										
										
										
											2020-06-30 18:02:44 +08:00
										 |  |  |     assert Batch().shape == [] | 
					
						
							| 
									
										
										
										
											2020-07-19 15:20:35 +08:00
										 |  |  |     assert Batch(a=1).shape == [] | 
					
						
							| 
									
										
										
										
											2020-08-27 12:15:18 +08:00
										 |  |  |     assert Batch(a=set((1, 2, 1))).shape == [] | 
					
						
							| 
									
										
										
										
											2020-06-30 18:02:44 +08:00
										 |  |  |     assert batch2.shape[0] == 1 | 
					
						
							| 
									
										
										
										
											2020-08-27 12:15:18 +08:00
										 |  |  |     assert 'a' in batch2 and all([i in batch2.a for i in 'bcd']) | 
					
						
							| 
									
										
										
										
											2020-06-24 15:43:48 +02:00
										 |  |  |     with pytest.raises(IndexError): | 
					
						
							|  |  |  |         batch2[-2] | 
					
						
							|  |  |  |     with pytest.raises(IndexError): | 
					
						
							|  |  |  |         batch2[1] | 
					
						
							| 
									
										
										
										
											2020-06-30 18:02:44 +08:00
										 |  |  |     assert batch2[0].shape == [] | 
					
						
							|  |  |  |     with pytest.raises(IndexError): | 
					
						
							| 
									
										
										
										
											2020-06-24 15:43:48 +02:00
										 |  |  |         batch2[0][0] | 
					
						
							| 
									
										
										
										
											2020-06-26 12:37:50 +02:00
										 |  |  |     with pytest.raises(TypeError): | 
					
						
							|  |  |  |         len(batch2[0]) | 
					
						
							| 
									
										
										
										
											2020-06-24 00:55:24 +02:00
										 |  |  |     assert isinstance(batch2[0].a.c, np.ndarray) | 
					
						
							|  |  |  |     assert isinstance(batch2[0].a.b, np.float64) | 
					
						
							|  |  |  |     assert isinstance(batch2[0].a.d.e, np.float64) | 
					
						
							| 
									
										
										
										
											2020-06-24 15:43:48 +02:00
										 |  |  |     batch2_from_list = Batch(list(batch2)) | 
					
						
							|  |  |  |     batch2_from_comp = Batch([e for e in batch2]) | 
					
						
							|  |  |  |     assert batch2_from_list.a.b == batch2.a.b | 
					
						
							|  |  |  |     assert batch2_from_list.a.c == batch2.a.c | 
					
						
							|  |  |  |     assert batch2_from_list.a.d.e == batch2.a.d.e | 
					
						
							|  |  |  |     assert batch2_from_comp.a.b == batch2.a.b | 
					
						
							|  |  |  |     assert batch2_from_comp.a.c == batch2.a.c | 
					
						
							|  |  |  |     assert batch2_from_comp.a.d.e == batch2.a.d.e | 
					
						
							| 
									
										
										
										
											2020-08-27 12:15:18 +08:00
										 |  |  |     for batch_slice in [batch2[slice(0, 1)], batch2[:1], batch2[0:]]: | 
					
						
							| 
									
										
										
										
											2020-06-24 15:43:48 +02:00
										 |  |  |         assert batch_slice.a.b == batch2.a.b | 
					
						
							|  |  |  |         assert batch_slice.a.c == batch2.a.c | 
					
						
							|  |  |  |         assert batch_slice.a.d.e == batch2.a.d.e | 
					
						
							| 
									
										
										
										
											2020-08-27 12:15:18 +08:00
										 |  |  |     batch2.a.d.f = {} | 
					
						
							| 
									
										
										
										
											2020-06-24 15:43:48 +02:00
										 |  |  |     batch2_sum = (batch2 + 1.0) * 2 | 
					
						
							|  |  |  |     assert batch2_sum.a.b == (batch2.a.b + 1.0) * 2 | 
					
						
							|  |  |  |     assert batch2_sum.a.c == (batch2.a.c + 1.0) * 2 | 
					
						
							|  |  |  |     assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2 | 
					
						
							| 
									
										
										
										
											2020-08-27 12:15:18 +08:00
										 |  |  |     assert batch2_sum.a.d.f.is_empty() | 
					
						
							|  |  |  |     with pytest.raises(TypeError): | 
					
						
							|  |  |  |         batch2 += [1] | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     batch3 = Batch( | 
					
						
							|  |  |  |         a={ | 
					
						
							|  |  |  |             'c': np.zeros(1), | 
					
						
							|  |  |  |             'd': Batch(e=np.array([0.0]), f=np.array([3.0])) | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2020-06-25 14:39:30 +02:00
										 |  |  |     batch3.a.d[0] = {'e': 4.0} | 
					
						
							|  |  |  |     assert batch3.a.d.e[0] == 4.0 | 
					
						
							|  |  |  |     batch3.a.d[0] = Batch(f=5.0) | 
					
						
							|  |  |  |     assert batch3.a.d.f[0] == 5.0 | 
					
						
							| 
									
										
										
											
												fix 2 bugs of batch (#284)
1. `_create_value(Batch(a={}, b=[1, 2, 3]), 10, False)`
before:
```python
TypeError: cannot concatenate with Batch() which is scalar
```
after:
```python
Batch(
    a: Batch(),
    b: array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
)
```
2. creating keys in a batch's subkey, e.g. 
```python
a = Batch(info={"key1": [0, 1], "key2": [2, 3]})
a[0] = Batch(info={"key1": 2, "key3": 4})
print(a)
```
before:
```python
Batch(
    info: Batch(
              key1: array([0, 1]),
              key2: array([0, 3]),
          ),
)
```
after:
```python
ValueError: Creating keys is not supported by item assignment.
```
3. small optimization for `Batch.stack_` and `Batch.cat_`, raise ValueError when receiving invalid data format.
											
										 
											2021-02-02 19:28:05 +08:00
										 |  |  |     with pytest.raises(ValueError): | 
					
						
							| 
									
										
										
										
											2020-06-25 14:39:30 +02:00
										 |  |  |         batch3.a.d[0] = Batch(f=5.0, g=0.0) | 
					
						
							| 
									
										
										
											
												fix 2 bugs of batch (#284)
1. `_create_value(Batch(a={}, b=[1, 2, 3]), 10, False)`
before:
```python
TypeError: cannot concatenate with Batch() which is scalar
```
after:
```python
Batch(
    a: Batch(),
    b: array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
)
```
2. creating keys in a batch's subkey, e.g. 
```python
a = Batch(info={"key1": [0, 1], "key2": [2, 3]})
a[0] = Batch(info={"key1": 2, "key3": 4})
print(a)
```
before:
```python
Batch(
    info: Batch(
              key1: array([0, 1]),
              key2: array([0, 3]),
          ),
)
```
after:
```python
ValueError: Creating keys is not supported by item assignment.
```
3. small optimization for `Batch.stack_` and `Batch.cat_`, raise ValueError when receiving invalid data format.
											
										 
											2021-02-02 19:28:05 +08:00
										 |  |  |     with pytest.raises(ValueError): | 
					
						
							|  |  |  |         batch3[0] = Batch(a={"c": 2, "e": 1}) | 
					
						
							| 
									
										
										
										
											2020-07-11 21:46:01 +08:00
										 |  |  |     # auto convert | 
					
						
							|  |  |  |     batch4 = Batch(a=np.array(['a', 'b'])) | 
					
						
							| 
									
										
										
										
											2021-03-30 16:06:03 +08:00
										 |  |  |     assert batch4.a.dtype == object  # auto convert to object | 
					
						
							| 
									
										
										
										
											2020-07-11 21:46:01 +08:00
										 |  |  |     batch4.update(a=np.array(['c', 'd'])) | 
					
						
							|  |  |  |     assert list(batch4.a) == ['c', 'd'] | 
					
						
							| 
									
										
										
										
											2021-03-30 16:06:03 +08:00
										 |  |  |     assert batch4.a.dtype == object  # auto convert to object | 
					
						
							| 
									
										
										
										
											2020-07-11 21:46:01 +08:00
										 |  |  |     batch5 = Batch(a=np.array([{'index': 0}])) | 
					
						
							|  |  |  |     assert isinstance(batch5.a, Batch) | 
					
						
							|  |  |  |     assert np.allclose(batch5.a.index, [0]) | 
					
						
							|  |  |  |     batch5.b = np.array([{'index': 1}]) | 
					
						
							|  |  |  |     assert isinstance(batch5.b, Batch) | 
					
						
							|  |  |  |     assert np.allclose(batch5.b.index, [1]) | 
					
						
							| 
									
										
										
										
											2020-03-13 17:49:22 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-07-16 19:36:32 +08:00
										 |  |  |     # None is a valid object and can be stored in Batch | 
					
						
							|  |  |  |     a = Batch.stack([Batch(a=None), Batch(b=None)]) | 
					
						
							|  |  |  |     assert a.a[0] is None and a.a[1] is None | 
					
						
							|  |  |  |     assert a.b[0] is None and a.b[1] is None | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-25 15:23:46 +08:00
										 |  |  |     # nx.Graph corner case | 
					
						
							|  |  |  |     assert Batch(a=np.array([nx.Graph(), nx.Graph()], dtype=object)).a.dtype == object | 
					
						
							|  |  |  |     g1 = nx.Graph() | 
					
						
							|  |  |  |     g1.add_nodes_from(list(range(10))) | 
					
						
							|  |  |  |     g2 = nx.Graph() | 
					
						
							|  |  |  |     g2.add_nodes_from(list(range(20))) | 
					
						
							|  |  |  |     assert Batch(a=np.array([g1, g2])).a.dtype == object | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-13 17:49:22 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-27 11:02:23 +08:00
										 |  |  | def test_batch_over_batch(): | 
					
						
							|  |  |  |     batch = Batch(a=[3, 4, 5], b=[4, 5, 6]) | 
					
						
							| 
									
										
										
										
											2020-06-23 16:50:59 +02:00
										 |  |  |     batch2 = Batch({'c': [6, 7, 8], 'b': batch}) | 
					
						
							| 
									
										
										
										
											2020-05-27 11:02:23 +08:00
										 |  |  |     batch2.b.b[-1] = 0 | 
					
						
							|  |  |  |     print(batch2) | 
					
						
							| 
									
										
										
										
											2020-06-23 16:50:59 +02:00
										 |  |  |     for k, v in batch2.items(): | 
					
						
							| 
									
										
										
										
											2020-06-30 18:02:44 +08:00
										 |  |  |         assert np.all(batch2[k] == v) | 
					
						
							| 
									
										
										
										
											2020-05-27 11:02:23 +08:00
										 |  |  |     assert batch2[-1].b.b == 0 | 
					
						
							| 
									
										
										
										
											2020-06-23 16:50:59 +02:00
										 |  |  |     batch2.cat_(Batch(c=[6, 7, 8], b=batch)) | 
					
						
							| 
									
										
										
										
											2020-06-30 18:02:44 +08:00
										 |  |  |     assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8]) | 
					
						
							|  |  |  |     assert np.allclose(batch2.b.a, [3, 4, 5, 3, 4, 5]) | 
					
						
							|  |  |  |     assert np.allclose(batch2.b.b, [4, 5, 0, 4, 5, 0]) | 
					
						
							| 
									
										
										
										
											2020-07-11 21:46:01 +08:00
										 |  |  |     batch2.update(batch2.b, six=[6, 6, 6]) | 
					
						
							|  |  |  |     assert np.allclose(batch2.c, [6, 7, 8, 6, 7, 8]) | 
					
						
							|  |  |  |     assert np.allclose(batch2.a, [3, 4, 5, 3, 4, 5]) | 
					
						
							|  |  |  |     assert np.allclose(batch2.b, [4, 5, 0, 4, 5, 0]) | 
					
						
							|  |  |  |     assert np.allclose(batch2.six, [6, 6, 6]) | 
					
						
							| 
									
										
										
										
											2020-06-20 22:03:22 +08:00
										 |  |  |     d = {'a': [3, 4, 5], 'b': [4, 5, 6]} | 
					
						
							|  |  |  |     batch3 = Batch(c=[6, 7, 8], b=d) | 
					
						
							| 
									
										
										
										
											2020-06-23 16:50:59 +02:00
										 |  |  |     batch3.cat_(Batch(c=[6, 7, 8], b=d)) | 
					
						
							| 
									
										
										
										
											2020-06-30 18:02:44 +08:00
										 |  |  |     assert np.allclose(batch3.c, [6, 7, 8, 6, 7, 8]) | 
					
						
							|  |  |  |     assert np.allclose(batch3.b.a, [3, 4, 5, 3, 4, 5]) | 
					
						
							|  |  |  |     assert np.allclose(batch3.b.b, [4, 5, 6, 4, 5, 6]) | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     batch4 = Batch(({'a': {'b': np.array([1.0])}}, )) | 
					
						
							| 
									
										
										
										
											2020-06-23 17:37:26 +02:00
										 |  |  |     assert batch4.a.b.ndim == 2 | 
					
						
							|  |  |  |     assert batch4.a.b[0, 0] == 1.0 | 
					
						
							| 
									
										
										
										
											2020-06-30 18:02:44 +08:00
										 |  |  |     # advanced slicing | 
					
						
							|  |  |  |     batch5 = Batch(a=[[1, 2]], b={'c': np.zeros([3, 2, 1])}) | 
					
						
							|  |  |  |     assert batch5.shape == [1, 2] | 
					
						
							|  |  |  |     with pytest.raises(IndexError): | 
					
						
							|  |  |  |         batch5[2] | 
					
						
							|  |  |  |     with pytest.raises(IndexError): | 
					
						
							|  |  |  |         batch5[:, 3] | 
					
						
							|  |  |  |     with pytest.raises(IndexError): | 
					
						
							|  |  |  |         batch5[:, :, -1] | 
					
						
							|  |  |  |     batch5[:, -1] += 1 | 
					
						
							|  |  |  |     assert np.allclose(batch5.a, [1, 3]) | 
					
						
							|  |  |  |     assert np.allclose(batch5.b.c.squeeze(), [[0, 1]] * 3) | 
					
						
							| 
									
										
										
										
											2020-08-27 12:15:18 +08:00
										 |  |  |     with pytest.raises(ValueError): | 
					
						
							|  |  |  |         batch5[:, -1] = 1 | 
					
						
							|  |  |  |     batch5[:, 0] = {'a': -1} | 
					
						
							|  |  |  |     assert np.allclose(batch5.a, [-1, 3]) | 
					
						
							|  |  |  |     assert np.allclose(batch5.b.c.squeeze(), [[0, 1]] * 3) | 
					
						
							| 
									
										
										
										
											2020-05-27 11:02:23 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-07-06 20:30:15 +08:00
										 |  |  | def test_batch_cat_and_stack(): | 
					
						
							| 
									
										
										
										
											2020-07-11 21:46:01 +08:00
										 |  |  |     # test cat with compatible keys | 
					
						
							| 
									
										
										
										
											2020-06-24 00:55:24 +02:00
										 |  |  |     b1 = Batch(a=[{'b': np.float64(1.0), 'd': Batch(e=np.array(3.0))}]) | 
					
						
							|  |  |  |     b2 = Batch(a=[{'b': np.float64(4.0), 'd': {'e': np.array(6.0)}}]) | 
					
						
							| 
									
										
										
										
											2020-07-11 21:46:01 +08:00
										 |  |  |     b12_cat_out = Batch.cat([b1, b2]) | 
					
						
							| 
									
										
										
										
											2020-06-25 14:39:30 +02:00
										 |  |  |     b12_cat_in = copy.deepcopy(b1) | 
					
						
							|  |  |  |     b12_cat_in.cat_(b2) | 
					
						
							|  |  |  |     assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e) | 
					
						
							|  |  |  |     assert np.all(b12_cat_in.a.d.e == b12_cat_out.a.d.e) | 
					
						
							|  |  |  |     assert isinstance(b12_cat_in.a.d.e, np.ndarray) | 
					
						
							|  |  |  |     assert b12_cat_in.a.d.e.ndim == 1 | 
					
						
							| 
									
										
										
										
											2020-07-11 21:46:01 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-07-16 19:36:32 +08:00
										 |  |  |     a = Batch(a=Batch(a=np.random.randn(3, 4))) | 
					
						
							|  |  |  |     assert np.allclose( | 
					
						
							|  |  |  |         np.concatenate([a.a.a, a.a.a]), | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |         Batch.cat([a, Batch(a=Batch(a=Batch())), a]).a.a | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2020-07-16 19:36:32 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # test cat with lens infer | 
					
						
							|  |  |  |     a = Batch(a=Batch(a=np.random.randn(3, 4)), b=np.random.randn(3, 4)) | 
					
						
							|  |  |  |     b = Batch(a=Batch(a=Batch(), t=Batch()), b=np.random.randn(3, 4)) | 
					
						
							|  |  |  |     ans = Batch.cat([a, b, a]) | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     assert np.allclose(ans.a.a, np.concatenate([a.a.a, np.zeros((3, 4)), a.a.a])) | 
					
						
							| 
									
										
										
										
											2020-07-16 19:36:32 +08:00
										 |  |  |     assert np.allclose(ans.b, np.concatenate([a.b, b.b, a.b])) | 
					
						
							|  |  |  |     assert ans.a.t.is_empty() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-08-27 12:15:18 +08:00
										 |  |  |     assert b1.stack_([b2]) is None | 
					
						
							|  |  |  |     assert isinstance(b1.a.d.e, np.ndarray) | 
					
						
							|  |  |  |     assert b1.a.d.e.ndim == 2 | 
					
						
							| 
									
										
										
										
											2020-07-11 21:46:01 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-07-12 23:45:42 +08:00
										 |  |  |     # test cat with incompatible keys | 
					
						
							| 
									
										
										
										
											2020-07-11 21:46:01 +08:00
										 |  |  |     b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5))) | 
					
						
							|  |  |  |     b2 = Batch(b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5))) | 
					
						
							|  |  |  |     test = Batch.cat([b1, b2]) | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     ans = Batch( | 
					
						
							|  |  |  |         a=np.concatenate([b1.a, np.zeros((4, 4))]), | 
					
						
							|  |  |  |         b=torch.cat([torch.zeros(3, 3), b2.b]), | 
					
						
							|  |  |  |         common=Batch(c=np.concatenate([b1.common.c, b2.common.c])) | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2020-07-11 21:46:01 +08:00
										 |  |  |     assert np.allclose(test.a, ans.a) | 
					
						
							|  |  |  |     assert torch.allclose(test.b, ans.b) | 
					
						
							|  |  |  |     assert np.allclose(test.common.c, ans.common.c) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-07-16 19:36:32 +08:00
										 |  |  |     # test cat with reserved keys (values are Batch()) | 
					
						
							|  |  |  |     b1 = Batch(a=np.random.rand(3, 4), common=Batch(c=np.random.rand(3, 5))) | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     b2 = Batch(a=Batch(), b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5))) | 
					
						
							| 
									
										
										
										
											2020-07-16 19:36:32 +08:00
										 |  |  |     test = Batch.cat([b1, b2]) | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     ans = Batch( | 
					
						
							|  |  |  |         a=np.concatenate([b1.a, np.zeros((4, 4))]), | 
					
						
							|  |  |  |         b=torch.cat([torch.zeros(3, 3), b2.b]), | 
					
						
							|  |  |  |         common=Batch(c=np.concatenate([b1.common.c, b2.common.c])) | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2020-07-16 19:36:32 +08:00
										 |  |  |     assert np.allclose(test.a, ans.a) | 
					
						
							|  |  |  |     assert torch.allclose(test.b, ans.b) | 
					
						
							|  |  |  |     assert np.allclose(test.common.c, ans.common.c) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # test cat with all reserved keys (values are Batch()) | 
					
						
							|  |  |  |     b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(3, 5))) | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     b2 = Batch(a=Batch(), b=torch.rand(4, 3), common=Batch(c=np.random.rand(4, 5))) | 
					
						
							| 
									
										
										
										
											2020-07-16 19:36:32 +08:00
										 |  |  |     test = Batch.cat([b1, b2]) | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     ans = Batch( | 
					
						
							|  |  |  |         a=Batch(), | 
					
						
							|  |  |  |         b=torch.cat([torch.zeros(3, 3), b2.b]), | 
					
						
							|  |  |  |         common=Batch(c=np.concatenate([b1.common.c, b2.common.c])) | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2020-07-16 19:36:32 +08:00
										 |  |  |     assert ans.a.is_empty() | 
					
						
							|  |  |  |     assert torch.allclose(test.b, ans.b) | 
					
						
							|  |  |  |     assert np.allclose(test.common.c, ans.common.c) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-07-12 23:45:42 +08:00
										 |  |  |     # test stack with compatible keys | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     b3 = Batch(a=np.zeros((3, 4)), b=torch.ones((2, 5)), c=Batch(d=[[1], [2]])) | 
					
						
							|  |  |  |     b4 = Batch(a=np.ones((3, 4)), b=torch.ones((2, 5)), c=Batch(d=[[0], [3]])) | 
					
						
							| 
									
										
										
										
											2020-06-25 14:39:30 +02:00
										 |  |  |     b34_stack = Batch.stack((b3, b4), axis=1) | 
					
						
							|  |  |  |     assert np.all(b34_stack.a == np.stack((b3.a, b4.a), axis=1)) | 
					
						
							| 
									
										
										
										
											2020-06-26 12:37:50 +02:00
										 |  |  |     assert np.all(b34_stack.c.d == list(map(list, zip(b3.c.d, b4.c.d)))) | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     b5_dict = np.array( | 
					
						
							|  |  |  |         [{ | 
					
						
							|  |  |  |             'a': False, | 
					
						
							|  |  |  |             'b': { | 
					
						
							|  |  |  |                 'c': 2.0, | 
					
						
							|  |  |  |                 'd': 1.0 | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         }, { | 
					
						
							|  |  |  |             'a': True, | 
					
						
							|  |  |  |             'b': { | 
					
						
							|  |  |  |                 'c': 3.0 | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         }] | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2020-06-27 03:06:40 +02:00
										 |  |  |     b5 = Batch(b5_dict) | 
					
						
							|  |  |  |     assert b5.a[0] == np.array(False) and b5.a[1] == np.array(True) | 
					
						
							|  |  |  |     assert np.all(b5.b.c == np.stack([e['b']['c'] for e in b5_dict], axis=0)) | 
					
						
							|  |  |  |     assert b5.b.d[0] == b5_dict[0]['b']['d'] | 
					
						
							|  |  |  |     assert b5.b.d[1] == 0.0 | 
					
						
							| 
									
										
										
										
											2020-06-23 16:50:59 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-07-12 23:45:42 +08:00
										 |  |  |     # test stack with incompatible keys | 
					
						
							|  |  |  |     a = Batch(a=1, b=2, c=3) | 
					
						
							|  |  |  |     b = Batch(a=4, b=5, d=6) | 
					
						
							|  |  |  |     c = Batch(c=7, b=6, d=9) | 
					
						
							|  |  |  |     d = Batch.stack([a, b, c]) | 
					
						
							|  |  |  |     assert np.allclose(d.a, [1, 4, 0]) | 
					
						
							|  |  |  |     assert np.allclose(d.b, [2, 5, 6]) | 
					
						
							|  |  |  |     assert np.allclose(d.c, [3, 0, 7]) | 
					
						
							|  |  |  |     assert np.allclose(d.d, [0, 6, 9]) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-07-16 19:36:32 +08:00
										 |  |  |     # test stack with empty Batch() | 
					
						
							|  |  |  |     assert Batch.stack([Batch(), Batch(), Batch()]).is_empty() | 
					
						
							|  |  |  |     a = Batch(a=1, b=2, c=3, d=Batch(), e=Batch()) | 
					
						
							|  |  |  |     b = Batch(a=4, b=5, d=6, e=Batch()) | 
					
						
							|  |  |  |     c = Batch(c=7, b=6, d=9, e=Batch()) | 
					
						
							|  |  |  |     d = Batch.stack([a, b, c]) | 
					
						
							|  |  |  |     assert np.allclose(d.a, [1, 4, 0]) | 
					
						
							|  |  |  |     assert np.allclose(d.b, [2, 5, 6]) | 
					
						
							|  |  |  |     assert np.allclose(d.c, [3, 0, 7]) | 
					
						
							|  |  |  |     assert np.allclose(d.d, [0, 6, 9]) | 
					
						
							|  |  |  |     assert d.e.is_empty() | 
					
						
							|  |  |  |     b1 = Batch(a=Batch(), common=Batch(c=np.random.rand(4, 5))) | 
					
						
							|  |  |  |     b2 = Batch(b=Batch(), common=Batch(c=np.random.rand(4, 5))) | 
					
						
							|  |  |  |     test = Batch.stack([b1, b2], axis=-1) | 
					
						
							|  |  |  |     assert test.a.is_empty() | 
					
						
							|  |  |  |     assert test.b.is_empty() | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     assert np.allclose(test.common.c, np.stack([b1.common.c, b2.common.c], axis=-1)) | 
					
						
							| 
									
										
										
										
											2020-07-16 19:36:32 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-07-12 23:45:42 +08:00
										 |  |  |     b1 = Batch(a=np.random.rand(4, 4), common=Batch(c=np.random.rand(4, 5))) | 
					
						
							|  |  |  |     b2 = Batch(b=torch.rand(4, 6), common=Batch(c=np.random.rand(4, 5))) | 
					
						
							|  |  |  |     test = Batch.stack([b1, b2]) | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     ans = Batch( | 
					
						
							|  |  |  |         a=np.stack([b1.a, np.zeros((4, 4))]), | 
					
						
							|  |  |  |         b=torch.stack([torch.zeros(4, 6), b2.b]), | 
					
						
							|  |  |  |         common=Batch(c=np.stack([b1.common.c, b2.common.c])) | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2020-07-12 23:45:42 +08:00
										 |  |  |     assert np.allclose(test.a, ans.a) | 
					
						
							|  |  |  |     assert torch.allclose(test.b, ans.b) | 
					
						
							|  |  |  |     assert np.allclose(test.common.c, ans.common.c) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
											
Fix 2 bugs in Batch (#284)
1. `_create_value(Batch(a={}, b=[1, 2, 3]), 10, False)`
before:
```python
TypeError: cannot concatenate with Batch() which is scalar
```
after:
```python
Batch(
    a: Batch(),
    b: array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
)
```
2. Creating new keys inside a batch's sub-batch via item assignment, e.g.
```python
a = Batch(info={"key1": [0, 1], "key2": [2, 3]})
a[0] = Batch(info={"key1": 2, "key3": 4})
print(a)
```
before:
```python
Batch(
    info: Batch(
              key1: array([0, 1]),
              key2: array([0, 3]),
          ),
)
```
after:
```python
ValueError: Creating keys is not supported by item assignment.
```
3. Small optimization for `Batch.stack_` and `Batch.cat_`; raise ValueError when receiving an invalid data format.
											
										 
											2021-02-02 19:28:05 +08:00
										 |  |  |     # test with illegal input format | 
					
						
							|  |  |  |     with pytest.raises(ValueError): | 
					
						
							|  |  |  |         Batch.cat([[Batch(a=1)], [Batch(a=1)]]) | 
					
						
							|  |  |  |     with pytest.raises(ValueError): | 
					
						
							|  |  |  |         Batch.stack([[Batch(a=1)], [Batch(a=1)]]) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-08-27 12:15:18 +08:00
										 |  |  |     # exceptions | 
					
						
							|  |  |  |     assert Batch.cat([]).is_empty() | 
					
						
							|  |  |  |     assert Batch.stack([]).is_empty() | 
					
						
							|  |  |  |     b1 = Batch(e=[4, 5], d=6) | 
					
						
							|  |  |  |     b2 = Batch(e=[4, 6]) | 
					
						
							|  |  |  |     with pytest.raises(ValueError): | 
					
						
							|  |  |  |         Batch.cat([b1, b2]) | 
					
						
							|  |  |  |     with pytest.raises(ValueError): | 
					
						
							|  |  |  |         Batch.stack([b1, b2], axis=1) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-06-23 16:50:59 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-30 15:40:31 +02:00
										 |  |  | def test_batch_over_batch_to_torch(): | 
					
						
							|  |  |  |     batch = Batch( | 
					
						
							| 
									
										
										
										
											2020-07-07 12:40:55 +02:00
										 |  |  |         a=np.float64(1.0), | 
					
						
							| 
									
										
										
										
											2020-05-30 15:40:31 +02:00
										 |  |  |         b=Batch( | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |             c=np.ones((1, ), dtype=np.float32), | 
					
						
							|  |  |  |             d=torch.ones((1, ), dtype=torch.float64) | 
					
						
							| 
									
										
										
										
											2020-05-30 15:40:31 +02:00
										 |  |  |         ) | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2020-08-27 12:15:18 +08:00
										 |  |  |     batch.b.__dict__['e'] = 1  # bypass the check | 
					
						
							| 
									
										
										
										
											2020-05-30 15:40:31 +02:00
										 |  |  |     batch.to_torch() | 
					
						
							|  |  |  |     assert isinstance(batch.a, torch.Tensor) | 
					
						
							|  |  |  |     assert isinstance(batch.b.c, torch.Tensor) | 
					
						
							|  |  |  |     assert isinstance(batch.b.d, torch.Tensor) | 
					
						
							| 
									
										
										
										
											2020-08-27 12:15:18 +08:00
										 |  |  |     assert isinstance(batch.b.e, torch.Tensor) | 
					
						
							| 
									
										
										
										
											2020-05-30 15:40:31 +02:00
										 |  |  |     assert batch.a.dtype == torch.float64 | 
					
						
							| 
									
										
										
										
											2020-07-07 12:40:55 +02:00
										 |  |  |     assert batch.b.c.dtype == torch.float32 | 
					
						
							| 
									
										
										
										
											2020-05-30 15:40:31 +02:00
										 |  |  |     assert batch.b.d.dtype == torch.float64 | 
					
						
							| 
									
										
										
										
											2021-03-31 15:14:22 +08:00
										 |  |  |     if sys.platform in ["win32", "cygwin"]:  # windows | 
					
						
							|  |  |  |         assert batch.b.e.dtype == torch.int32 | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         assert batch.b.e.dtype == torch.int64 | 
					
						
							| 
									
										
										
										
											2020-05-30 15:40:31 +02:00
										 |  |  |     batch.to_torch(dtype=torch.float32) | 
					
						
							|  |  |  |     assert batch.a.dtype == torch.float32 | 
					
						
							|  |  |  |     assert batch.b.c.dtype == torch.float32 | 
					
						
							|  |  |  |     assert batch.b.d.dtype == torch.float32 | 
					
						
							| 
									
										
										
										
											2020-08-27 12:15:18 +08:00
										 |  |  |     assert batch.b.e.dtype == torch.float32 | 
					
						
							| 
									
										
										
										
											2020-05-30 15:40:31 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-07-21 10:47:56 +02:00
										 |  |  | def test_utils_to_torch_numpy(): | 
					
						
							| 
									
										
										
										
											2020-05-30 15:40:31 +02:00
										 |  |  |     batch = Batch( | 
					
						
							| 
									
										
										
										
											2020-07-07 12:40:55 +02:00
										 |  |  |         a=np.float64(1.0), | 
					
						
							| 
									
										
										
										
											2020-05-30 15:40:31 +02:00
										 |  |  |         b=Batch( | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |             c=np.ones((1, ), dtype=np.float32), | 
					
						
							|  |  |  |             d=torch.ones((1, ), dtype=torch.float64) | 
					
						
							| 
									
										
										
										
											2020-05-30 15:40:31 +02:00
										 |  |  |         ) | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     a_torch_float = to_torch(batch.a, dtype=torch.float32) | 
					
						
							|  |  |  |     assert a_torch_float.dtype == torch.float32 | 
					
						
							|  |  |  |     a_torch_double = to_torch(batch.a, dtype=torch.float64) | 
					
						
							|  |  |  |     assert a_torch_double.dtype == torch.float64 | 
					
						
							|  |  |  |     batch_torch_float = to_torch(batch, dtype=torch.float32) | 
					
						
							|  |  |  |     assert batch_torch_float.a.dtype == torch.float32 | 
					
						
							|  |  |  |     assert batch_torch_float.b.c.dtype == torch.float32 | 
					
						
							|  |  |  |     assert batch_torch_float.b.d.dtype == torch.float32 | 
					
						
							| 
									
										
										
										
											2020-07-21 10:47:56 +02:00
										 |  |  |     data_list = [float('nan'), 1] | 
					
						
							|  |  |  |     data_list_torch = to_torch(data_list) | 
					
						
							|  |  |  |     assert data_list_torch.dtype == torch.float64 | 
					
						
							|  |  |  |     data_list_2 = [np.random.rand(3, 3), np.random.rand(3, 3)] | 
					
						
							|  |  |  |     data_list_2_torch = to_torch(data_list_2) | 
					
						
							|  |  |  |     assert data_list_2_torch.shape == (2, 3, 3) | 
					
						
							|  |  |  |     assert np.allclose(to_numpy(to_torch(data_list_2)), data_list_2) | 
					
						
							|  |  |  |     data_list_3 = [np.zeros((3, 2)), np.zeros((3, 3))] | 
					
						
							| 
									
										
										
										
											2021-03-30 16:06:03 +08:00
										 |  |  |     data_list_3_torch = [torch.zeros((3, 2)), torch.zeros((3, 3))] | 
					
						
							|  |  |  |     with pytest.raises(TypeError): | 
					
						
							|  |  |  |         to_torch(data_list_3) | 
					
						
							|  |  |  |     with pytest.raises(TypeError): | 
					
						
							|  |  |  |         to_numpy(data_list_3_torch) | 
					
						
							| 
									
										
										
										
											2020-07-21 10:47:56 +02:00
										 |  |  |     data_list_4 = [np.zeros((2, 3)), np.zeros((3, 3))] | 
					
						
							| 
									
										
										
										
											2021-03-30 16:06:03 +08:00
										 |  |  |     data_list_4_torch = [torch.zeros((2, 3)), torch.zeros((3, 3))] | 
					
						
							|  |  |  |     with pytest.raises(TypeError): | 
					
						
							|  |  |  |         to_torch(data_list_4) | 
					
						
							|  |  |  |     with pytest.raises(TypeError): | 
					
						
							|  |  |  |         to_numpy(data_list_4_torch) | 
					
						
							| 
									
										
										
										
											2020-07-21 10:47:56 +02:00
										 |  |  |     data_list_5 = [np.zeros(2), np.zeros((3, 3))] | 
					
						
							| 
									
										
										
										
											2021-03-30 16:06:03 +08:00
										 |  |  |     data_list_5_torch = [torch.zeros(2), torch.zeros((3, 3))] | 
					
						
							|  |  |  |     with pytest.raises(TypeError): | 
					
						
							|  |  |  |         to_torch(data_list_5) | 
					
						
							|  |  |  |     with pytest.raises(TypeError): | 
					
						
							|  |  |  |         to_numpy(data_list_5_torch) | 
					
						
							| 
									
										
										
										
											2020-07-21 10:47:56 +02:00
										 |  |  |     data_array = np.random.rand(3, 2, 2) | 
					
						
							|  |  |  |     data_empty_tensor = to_torch(data_array[[]]) | 
					
						
							|  |  |  |     assert isinstance(data_empty_tensor, torch.Tensor) | 
					
						
							|  |  |  |     assert data_empty_tensor.shape == (0, 2, 2) | 
					
						
							|  |  |  |     data_empty_array = to_numpy(data_empty_tensor) | 
					
						
							|  |  |  |     assert isinstance(data_empty_array, np.ndarray) | 
					
						
							|  |  |  |     assert data_empty_array.shape == (0, 2, 2) | 
					
						
							|  |  |  |     assert np.allclose(to_numpy(to_torch(data_array)), data_array) | 
					
						
							| 
									
										
										
										
											2020-08-27 12:15:18 +08:00
										 |  |  |     # additional test for to_numpy, for code-coverage | 
					
						
							|  |  |  |     assert isinstance(to_numpy(1), np.ndarray) | 
					
						
							|  |  |  |     assert isinstance(to_numpy(1.), np.ndarray) | 
					
						
							|  |  |  |     assert isinstance(to_numpy({'a': torch.tensor(1)})['a'], np.ndarray) | 
					
						
							|  |  |  |     assert isinstance(to_numpy(Batch(a=torch.tensor(1))).a, np.ndarray) | 
					
						
							|  |  |  |     assert to_numpy(None).item() is None | 
					
						
							|  |  |  |     assert to_numpy(to_numpy).item() == to_numpy | 
					
						
							|  |  |  |     # additional test for to_torch, for code-coverage | 
					
						
							|  |  |  |     assert isinstance(to_torch(1), torch.Tensor) | 
					
						
							| 
									
										
										
										
											2021-03-31 15:14:22 +08:00
										 |  |  |     if sys.platform in ["win32", "cygwin"]:  # windows | 
					
						
							|  |  |  |         assert to_torch(1).dtype == torch.int32 | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         assert to_torch(1).dtype == torch.int64 | 
					
						
							| 
									
										
										
										
											2020-08-27 12:15:18 +08:00
										 |  |  |     assert to_torch(1.).dtype == torch.float64 | 
					
						
							|  |  |  |     assert isinstance(to_torch({'a': [1]})['a'], torch.Tensor) | 
					
						
							|  |  |  |     with pytest.raises(TypeError): | 
					
						
							|  |  |  |         to_torch(None) | 
					
						
							|  |  |  |     with pytest.raises(TypeError): | 
					
						
							|  |  |  |         to_torch(np.array([{}, '2'])) | 
					
						
							| 
									
										
										
										
											2020-05-30 15:40:31 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-30 15:29:33 +02:00
										 |  |  | def test_batch_pickle(): | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     batch = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])), np=np.zeros([3, 4])) | 
					
						
							| 
									
										
										
										
											2020-05-30 15:29:33 +02:00
										 |  |  |     batch_pk = pickle.loads(pickle.dumps(batch)) | 
					
						
							|  |  |  |     assert batch.obs.a == batch_pk.obs.a | 
					
						
							|  |  |  |     assert torch.all(batch.obs.c == batch_pk.obs.c) | 
					
						
							|  |  |  |     assert np.all(batch.np == batch_pk.np) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-05-29 14:45:21 +02:00
										 |  |  | def test_batch_from_to_numpy_without_copy(): | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     batch = Batch(a=np.ones((1, )), b=Batch(c=np.ones((1, )))) | 
					
						
							| 
									
										
										
										
											2020-05-30 15:40:31 +02:00
										 |  |  |     a_mem_addr_orig = batch.a.__array_interface__['data'][0] | 
					
						
							|  |  |  |     c_mem_addr_orig = batch.b.c.__array_interface__['data'][0] | 
					
						
							| 
									
										
										
										
											2020-05-29 14:45:21 +02:00
										 |  |  |     batch.to_torch() | 
					
						
							|  |  |  |     batch.to_numpy() | 
					
						
							| 
									
										
										
										
											2020-05-30 15:40:31 +02:00
										 |  |  |     a_mem_addr_new = batch.a.__array_interface__['data'][0] | 
					
						
							|  |  |  |     c_mem_addr_new = batch.b.c.__array_interface__['data'][0] | 
					
						
							| 
									
										
										
										
											2020-05-29 14:45:21 +02:00
										 |  |  |     assert a_mem_addr_new == a_mem_addr_orig | 
					
						
							|  |  |  |     assert c_mem_addr_new == c_mem_addr_orig | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-07-06 20:30:15 +08:00
										 |  |  | def test_batch_copy(): | 
					
						
							|  |  |  |     batch = Batch(a=np.array([3, 4, 5]), b=np.array([4, 5, 6])) | 
					
						
							|  |  |  |     batch2 = Batch({'c': np.array([6, 7, 8]), 'b': batch}) | 
					
						
							|  |  |  |     orig_c_addr = batch2.c.__array_interface__['data'][0] | 
					
						
							|  |  |  |     orig_b_a_addr = batch2.b.a.__array_interface__['data'][0] | 
					
						
							|  |  |  |     orig_b_b_addr = batch2.b.b.__array_interface__['data'][0] | 
					
						
							|  |  |  |     # test with copy=False | 
					
						
							|  |  |  |     batch3 = Batch(copy=False, **batch2) | 
					
						
							|  |  |  |     curr_c_addr = batch3.c.__array_interface__['data'][0] | 
					
						
							|  |  |  |     curr_b_a_addr = batch3.b.a.__array_interface__['data'][0] | 
					
						
							|  |  |  |     curr_b_b_addr = batch3.b.b.__array_interface__['data'][0] | 
					
						
							|  |  |  |     assert batch2.c is batch3.c | 
					
						
							|  |  |  |     assert batch2.b is batch3.b | 
					
						
							|  |  |  |     assert batch2.b.a is batch3.b.a | 
					
						
							|  |  |  |     assert batch2.b.b is batch3.b.b | 
					
						
							|  |  |  |     assert orig_c_addr == curr_c_addr | 
					
						
							|  |  |  |     assert orig_b_a_addr == curr_b_a_addr | 
					
						
							|  |  |  |     assert orig_b_b_addr == curr_b_b_addr | 
					
						
							|  |  |  |     # test with copy=True | 
					
						
							|  |  |  |     batch3 = Batch(copy=True, **batch2) | 
					
						
							|  |  |  |     curr_c_addr = batch3.c.__array_interface__['data'][0] | 
					
						
							|  |  |  |     curr_b_a_addr = batch3.b.a.__array_interface__['data'][0] | 
					
						
							|  |  |  |     curr_b_b_addr = batch3.b.b.__array_interface__['data'][0] | 
					
						
							|  |  |  |     assert batch2.c is not batch3.c | 
					
						
							|  |  |  |     assert batch2.b is not batch3.b | 
					
						
							|  |  |  |     assert batch2.b.a is not batch3.b.a | 
					
						
							|  |  |  |     assert batch2.b.b is not batch3.b.b | 
					
						
							|  |  |  |     assert orig_c_addr != curr_c_addr | 
					
						
							|  |  |  |     assert orig_b_a_addr != curr_b_a_addr | 
					
						
							|  |  |  |     assert orig_b_b_addr != curr_b_b_addr | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def test_batch_empty(): | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     b5_dict = np.array( | 
					
						
							|  |  |  |         [{ | 
					
						
							|  |  |  |             'a': False, | 
					
						
							|  |  |  |             'b': { | 
					
						
							|  |  |  |                 'c': 2.0, | 
					
						
							|  |  |  |                 'd': 1.0 | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         }, { | 
					
						
							|  |  |  |             'a': True, | 
					
						
							|  |  |  |             'b': { | 
					
						
							|  |  |  |                 'c': 3.0 | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         }] | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2020-07-06 20:30:15 +08:00
										 |  |  |     b5 = Batch(b5_dict) | 
					
						
							|  |  |  |     b5[1] = Batch.empty(b5[0]) | 
					
						
							|  |  |  |     assert np.allclose(b5.a, [False, False]) | 
					
						
							|  |  |  |     assert np.allclose(b5.b.c, [2, 0]) | 
					
						
							|  |  |  |     assert np.allclose(b5.b.d, [1, 0]) | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     data = Batch( | 
					
						
							|  |  |  |         a=[False, True], | 
					
						
							|  |  |  |         b={ | 
					
						
							|  |  |  |             'c': np.array([2., 'st'], dtype=object), | 
					
						
							|  |  |  |             'd': [1, None], | 
					
						
							|  |  |  |             'e': [2., float('nan')] | 
					
						
							|  |  |  |         }, | 
					
						
							|  |  |  |         c=np.array([1, 3, 4], dtype=int), | 
					
						
							|  |  |  |         t=torch.tensor([4, 5, 6, 7.]) | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2020-07-06 20:30:15 +08:00
										 |  |  |     data[-1] = Batch.empty(data[1]) | 
					
						
							|  |  |  |     assert np.allclose(data.c, [1, 3, 0]) | 
					
						
							|  |  |  |     assert np.allclose(data.a, [False, False]) | 
					
						
							|  |  |  |     assert list(data.b.c) == [2.0, None] | 
					
						
							|  |  |  |     assert list(data.b.d) == [1, None] | 
					
						
							|  |  |  |     assert np.allclose(data.b.e, [2, 0]) | 
					
						
							|  |  |  |     assert torch.allclose(data.t, torch.tensor([4, 5, 6, 0.])) | 
					
						
							|  |  |  |     data[0].empty_()  # which will fail in a, b.c, b.d, b.e, c | 
					
						
							|  |  |  |     assert torch.allclose(data.t, torch.tensor([0., 5, 6, 0])) | 
					
						
							|  |  |  |     data.empty_(index=0) | 
					
						
							|  |  |  |     assert np.allclose(data.c, [0, 3, 0]) | 
					
						
							|  |  |  |     assert list(data.b.c) == [None, None] | 
					
						
							|  |  |  |     assert list(data.b.d) == [None, None] | 
					
						
							|  |  |  |     assert list(data.b.e) == [0, 0] | 
					
						
							|  |  |  |     b0 = Batch() | 
					
						
							|  |  |  |     b0.empty_() | 
					
						
							|  |  |  |     assert b0.shape == [] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-07-08 16:29:37 +02:00
										 |  |  | def test_batch_standard_compatibility(): | 
					
						
							| 
									
										
										
										
											2021-09-03 05:05:04 +08:00
										 |  |  |     batch = Batch( | 
					
						
							|  |  |  |         a=np.array([[1.0, 2.0], [3.0, 4.0]]), b=Batch(), c=np.array([5.0, 6.0]) | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2020-06-24 15:43:48 +02:00
										 |  |  |     batch_mean = np.mean(batch) | 
					
						
							|  |  |  |     assert isinstance(batch_mean, Batch) | 
					
						
							|  |  |  |     assert sorted(batch_mean.keys()) == ['a', 'b', 'c'] | 
					
						
							|  |  |  |     with pytest.raises(TypeError): | 
					
						
							|  |  |  |         len(batch_mean) | 
					
						
							|  |  |  |     assert np.all(batch_mean.a == np.mean(batch.a, axis=0)) | 
					
						
							|  |  |  |     assert batch_mean.c == np.mean(batch.c, axis=0) | 
					
						
							| 
									
										
										
										
											2020-07-08 16:29:37 +02:00
										 |  |  |     with pytest.raises(IndexError): | 
					
						
							|  |  |  |         Batch()[0] | 
					
						
							| 
									
										
										
										
											2020-06-24 15:43:48 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-13 17:49:22 +08:00
										 |  |  | if __name__ == '__main__': | 
					
						
							|  |  |  |     test_batch() | 
					
						
							| 
									
										
										
										
											2020-05-27 11:02:23 +08:00
										 |  |  |     test_batch_over_batch() | 
					
						
							| 
									
										
										
										
											2020-06-01 08:30:09 +08:00
										 |  |  |     test_batch_over_batch_to_torch() | 
					
						
							| 
									
										
										
										
											2020-07-21 10:47:56 +02:00
										 |  |  |     test_utils_to_torch_numpy() | 
					
						
							| 
									
										
										
										
											2020-06-01 08:30:09 +08:00
										 |  |  |     test_batch_pickle() | 
					
						
							|  |  |  |     test_batch_from_to_numpy_without_copy() | 
					
						
							| 
									
										
										
										
											2020-07-08 16:29:37 +02:00
										 |  |  |     test_batch_standard_compatibility() | 
					
						
							| 
									
										
										
										
											2020-07-06 20:30:15 +08:00
										 |  |  |     test_batch_cat_and_stack() | 
					
						
							|  |  |  |     test_batch_copy() | 
					
						
							|  |  |  |     test_batch_empty() |