fix 2 bugs of batch (#284)
1. `_create_value(Batch(a={}, b=[1, 2, 3]), 10, False)`

   before:
   ```python
   TypeError: cannot concatenate with Batch() which is scalar
   ```
   after:
   ```python
   Batch(
       a: Batch(),
       b: array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
   )
   ```

2. creating keys in a batch's subkey, e.g.
   ```python
   a = Batch(info={"key1": [0, 1], "key2": [2, 3]})
   a[0] = Batch(info={"key1": 2, "key3": 4})
   print(a)
   ```
   before:
   ```python
   Batch(
       info: Batch(
           key1: array([0, 1]),
           key2: array([0, 3]),
       ),
   )
   ```
   after:
   ```python
   ValueError: Creating keys is not supported by item assignment.
   ```

3. small optimization for `Batch.stack_` and `Batch.cat_`
This commit is contained in:
parent
f528131da1
commit
d003c8e566
@ -19,6 +19,8 @@ def test_batch():
|
||||
assert not Batch(a=np.float64(1.0)).is_empty()
|
||||
assert len(Batch(a=[1, 2, 3], b={'c': {}})) == 3
|
||||
assert not Batch(a=[1, 2, 3]).is_empty()
|
||||
b = Batch({'a': [4, 4], 'b': [5, 5]}, c=[None, None])
|
||||
assert b.c.dtype == np.object
|
||||
b = Batch()
|
||||
b.update()
|
||||
assert b.is_empty()
|
||||
@ -143,8 +145,10 @@ def test_batch():
|
||||
assert batch3.a.d.e[0] == 4.0
|
||||
batch3.a.d[0] = Batch(f=5.0)
|
||||
assert batch3.a.d.f[0] == 5.0
|
||||
with pytest.raises(KeyError):
|
||||
with pytest.raises(ValueError):
|
||||
batch3.a.d[0] = Batch(f=5.0, g=0.0)
|
||||
with pytest.raises(ValueError):
|
||||
batch3[0] = Batch(a={"c": 2, "e": 1})
|
||||
# auto convert
|
||||
batch4 = Batch(a=np.array(['a', 'b']))
|
||||
assert batch4.a.dtype == np.object # auto convert to np.object
|
||||
@ -333,6 +337,12 @@ def test_batch_cat_and_stack():
|
||||
assert torch.allclose(test.b, ans.b)
|
||||
assert np.allclose(test.common.c, ans.common.c)
|
||||
|
||||
# test with illegal input format
|
||||
with pytest.raises(ValueError):
|
||||
Batch.cat([[Batch(a=1)], [Batch(a=1)]])
|
||||
with pytest.raises(ValueError):
|
||||
Batch.stack([[Batch(a=1)], [Batch(a=1)]])
|
||||
|
||||
# exceptions
|
||||
assert Batch.cat([]).is_empty()
|
||||
assert Batch.stack([]).is_empty()
|
||||
|
@ -8,12 +8,6 @@ from collections.abc import Collection
|
||||
from typing import Any, List, Dict, Union, Iterator, Optional, Iterable, \
|
||||
Sequence
|
||||
|
||||
# Disable pickle warning related to torch, since it has been removed
|
||||
# on torch master branch. See Pull Request #39003 for details:
|
||||
# https://github.com/pytorch/pytorch/pull/39003
|
||||
warnings.filterwarnings(
|
||||
"ignore", message="pickle support for Storage will be removed in 1.5.")
|
||||
|
||||
|
||||
def _is_batch_set(data: Any) -> bool:
|
||||
# Batch set is a list/tuple of dict/Batch objects,
|
||||
@ -91,6 +85,9 @@ def _create_value(
|
||||
has_shape = isinstance(inst, (np.ndarray, torch.Tensor))
|
||||
is_scalar = _is_scalar(inst)
|
||||
if not stack and is_scalar:
|
||||
# _create_value(Batch(a={}, b=[1, 2, 3]), 10, False) will fail here
|
||||
if isinstance(inst, Batch) and inst.is_empty(recurse=True):
|
||||
return inst
|
||||
# should never hit since it has already checked in Batch.cat_
|
||||
# here we do not consider scalar types, following the behavior of numpy
|
||||
# which does not support concatenation of zero-dimensional arrays
|
||||
@ -257,7 +254,7 @@ class Batch:
|
||||
raise ValueError("Batch does not supported tensor assignment. "
|
||||
"Use a compatible Batch or dict instead.")
|
||||
if not set(value.keys()).issubset(self.__dict__.keys()):
|
||||
raise KeyError(
|
||||
raise ValueError(
|
||||
"Creating keys is not supported by item assignment.")
|
||||
for key, val in self.items():
|
||||
try:
|
||||
@ -449,12 +446,21 @@ class Batch:
|
||||
"""Concatenate a list of (or one) Batch objects into current batch."""
|
||||
if isinstance(batches, Batch):
|
||||
batches = [batches]
|
||||
if len(batches) == 0:
|
||||
# check input format
|
||||
batch_list = []
|
||||
for b in batches:
|
||||
if isinstance(b, dict):
|
||||
if len(b) > 0:
|
||||
batch_list.append(Batch(b))
|
||||
elif isinstance(b, Batch):
|
||||
# x.is_empty() means that x is Batch() and should be ignored
|
||||
if not b.is_empty():
|
||||
batch_list.append(b)
|
||||
else:
|
||||
raise ValueError(f"Cannot concatenate {type(b)} in Batch.cat_")
|
||||
if len(batch_list) == 0:
|
||||
return
|
||||
batches = [x if isinstance(x, Batch) else Batch(x) for x in batches]
|
||||
|
||||
# x.is_empty() means that x is Batch() and should be ignored
|
||||
batches = [x for x in batches if not x.is_empty()]
|
||||
batches = batch_list
|
||||
try:
|
||||
# x.is_empty(recurse=True) here means x is a nested empty batch
|
||||
# like Batch(a=Batch), and we have to treat it as length zero and
|
||||
@ -496,9 +502,22 @@ class Batch:
|
||||
self, batches: Sequence[Union[dict, "Batch"]], axis: int = 0
|
||||
) -> None:
|
||||
"""Stack a list of Batch object into current batch."""
|
||||
if len(batches) == 0:
|
||||
# check input format
|
||||
batch_list = []
|
||||
for b in batches:
|
||||
if isinstance(b, dict):
|
||||
if len(b) > 0:
|
||||
batch_list.append(Batch(b))
|
||||
elif isinstance(b, Batch):
|
||||
# x.is_empty() means that x is Batch() and should be ignored
|
||||
if not b.is_empty():
|
||||
batch_list.append(b)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Cannot concatenate {type(b)} in Batch.stack_")
|
||||
if len(batch_list) == 0:
|
||||
return
|
||||
batches = [x if isinstance(x, Batch) else Batch(x) for x in batches]
|
||||
batches = batch_list
|
||||
if not self.is_empty():
|
||||
batches = [self] + batches
|
||||
# collect non-empty keys
|
||||
|
@ -203,7 +203,7 @@ class ReplayBuffer:
|
||||
)
|
||||
try:
|
||||
value[self._index] = inst
|
||||
except KeyError:
|
||||
except ValueError:
|
||||
for key in set(inst.keys()).difference(value.__dict__.keys()):
|
||||
value.__dict__[key] = _create_value(inst[key], self._maxsize)
|
||||
value[self._index] = inst
|
||||
|
Loading…
x
Reference in New Issue
Block a user