Fix 2 bugs in Batch (#284)

1. `_create_value(Batch(a={}, b=[1, 2, 3]), 10, False)`

before:
```python
TypeError: cannot concatenate with Batch() which is scalar
```
after:
```python
Batch(
    a: Batch(),
    b: array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
)
```

2. creating new keys in a Batch's nested sub-Batch through item assignment, e.g.
```python
a = Batch(info={"key1": [0, 1], "key2": [2, 3]})
a[0] = Batch(info={"key1": 2, "key3": 4})
print(a)
```
before:
```python
Batch(
    info: Batch(
              key1: array([0, 1]),
              key2: array([0, 3]),
          ),
)
```
after:
```python
ValueError: Creating keys is not supported by item assignment.
```

3. small optimization for `Batch.stack_` and `Batch.cat_`
This commit is contained in:
n+e 2021-02-16 09:01:54 +08:00 committed by GitHub
parent f528131da1
commit d003c8e566
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 45 additions and 16 deletions

View File

@ -19,6 +19,8 @@ def test_batch():
assert not Batch(a=np.float64(1.0)).is_empty()
assert len(Batch(a=[1, 2, 3], b={'c': {}})) == 3
assert not Batch(a=[1, 2, 3]).is_empty()
b = Batch({'a': [4, 4], 'b': [5, 5]}, c=[None, None])
assert b.c.dtype == np.object
b = Batch()
b.update()
assert b.is_empty()
@ -143,8 +145,10 @@ def test_batch():
assert batch3.a.d.e[0] == 4.0
batch3.a.d[0] = Batch(f=5.0)
assert batch3.a.d.f[0] == 5.0
with pytest.raises(KeyError):
with pytest.raises(ValueError):
batch3.a.d[0] = Batch(f=5.0, g=0.0)
with pytest.raises(ValueError):
batch3[0] = Batch(a={"c": 2, "e": 1})
# auto convert
batch4 = Batch(a=np.array(['a', 'b']))
assert batch4.a.dtype == np.object # auto convert to np.object
@ -333,6 +337,12 @@ def test_batch_cat_and_stack():
assert torch.allclose(test.b, ans.b)
assert np.allclose(test.common.c, ans.common.c)
# test with illegal input format
with pytest.raises(ValueError):
Batch.cat([[Batch(a=1)], [Batch(a=1)]])
with pytest.raises(ValueError):
Batch.stack([[Batch(a=1)], [Batch(a=1)]])
# exceptions
assert Batch.cat([]).is_empty()
assert Batch.stack([]).is_empty()

View File

@ -8,12 +8,6 @@ from collections.abc import Collection
from typing import Any, List, Dict, Union, Iterator, Optional, Iterable, \
Sequence
# Disable pickle warning related to torch, since it has been removed
# on torch master branch. See Pull Request #39003 for details:
# https://github.com/pytorch/pytorch/pull/39003
warnings.filterwarnings(
"ignore", message="pickle support for Storage will be removed in 1.5.")
def _is_batch_set(data: Any) -> bool:
# Batch set is a list/tuple of dict/Batch objects,
@ -91,6 +85,9 @@ def _create_value(
has_shape = isinstance(inst, (np.ndarray, torch.Tensor))
is_scalar = _is_scalar(inst)
if not stack and is_scalar:
# _create_value(Batch(a={}, b=[1, 2, 3]), 10, False) will fail here
if isinstance(inst, Batch) and inst.is_empty(recurse=True):
return inst
# should never hit since it has already checked in Batch.cat_
# here we do not consider scalar types, following the behavior of numpy
# which does not support concatenation of zero-dimensional arrays
@ -257,7 +254,7 @@ class Batch:
raise ValueError("Batch does not supported tensor assignment. "
"Use a compatible Batch or dict instead.")
if not set(value.keys()).issubset(self.__dict__.keys()):
raise KeyError(
raise ValueError(
"Creating keys is not supported by item assignment.")
for key, val in self.items():
try:
@ -449,12 +446,21 @@ class Batch:
"""Concatenate a list of (or one) Batch objects into current batch."""
if isinstance(batches, Batch):
batches = [batches]
if len(batches) == 0:
# check input format
batch_list = []
for b in batches:
if isinstance(b, dict):
if len(b) > 0:
batch_list.append(Batch(b))
elif isinstance(b, Batch):
# x.is_empty() means that x is Batch() and should be ignored
if not b.is_empty():
batch_list.append(b)
else:
raise ValueError(f"Cannot concatenate {type(b)} in Batch.cat_")
if len(batch_list) == 0:
return
batches = [x if isinstance(x, Batch) else Batch(x) for x in batches]
# x.is_empty() means that x is Batch() and should be ignored
batches = [x for x in batches if not x.is_empty()]
batches = batch_list
try:
# x.is_empty(recurse=True) here means x is a nested empty batch
# like Batch(a=Batch), and we have to treat it as length zero and
@ -496,9 +502,22 @@ class Batch:
self, batches: Sequence[Union[dict, "Batch"]], axis: int = 0
) -> None:
"""Stack a list of Batch object into current batch."""
if len(batches) == 0:
# check input format
batch_list = []
for b in batches:
if isinstance(b, dict):
if len(b) > 0:
batch_list.append(Batch(b))
elif isinstance(b, Batch):
# x.is_empty() means that x is Batch() and should be ignored
if not b.is_empty():
batch_list.append(b)
else:
raise ValueError(
f"Cannot concatenate {type(b)} in Batch.stack_")
if len(batch_list) == 0:
return
batches = [x if isinstance(x, Batch) else Batch(x) for x in batches]
batches = batch_list
if not self.is_empty():
batches = [self] + batches
# collect non-empty keys

View File

@ -203,7 +203,7 @@ class ReplayBuffer:
)
try:
value[self._index] = inst
except KeyError:
except ValueError:
for key in set(inst.keys()).difference(value.__dict__.keys()):
value.__dict__[key] = _create_value(inst[key], self._maxsize)
value[self._index] = inst