fix 2 bugs of batch (#284)
1. `_create_value(Batch(a={}, b=[1, 2, 3]), 10, False)` before: ```python TypeError: cannot concatenate with Batch() which is scalar ``` after: ```python Batch( a: Batch(), b: array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), ) ``` 2. creating keys in a batch's subkey, e.g. ```python a = Batch(info={"key1": [0, 1], "key2": [2, 3]}) a[0] = Batch(info={"key1": 2, "key3": 4}) print(a) ``` before: ```python Batch( info: Batch( key1: array([0, 1]), key2: array([0, 3]), ), ) ``` after: ```python ValueError: Creating keys is not supported by item assignment. ``` 3. small optimization for `Batch.stack_` and `Batch.cat_`
This commit is contained in:
parent
f528131da1
commit
d003c8e566
@ -19,6 +19,8 @@ def test_batch():
|
|||||||
assert not Batch(a=np.float64(1.0)).is_empty()
|
assert not Batch(a=np.float64(1.0)).is_empty()
|
||||||
assert len(Batch(a=[1, 2, 3], b={'c': {}})) == 3
|
assert len(Batch(a=[1, 2, 3], b={'c': {}})) == 3
|
||||||
assert not Batch(a=[1, 2, 3]).is_empty()
|
assert not Batch(a=[1, 2, 3]).is_empty()
|
||||||
|
b = Batch({'a': [4, 4], 'b': [5, 5]}, c=[None, None])
|
||||||
|
assert b.c.dtype == np.object
|
||||||
b = Batch()
|
b = Batch()
|
||||||
b.update()
|
b.update()
|
||||||
assert b.is_empty()
|
assert b.is_empty()
|
||||||
@ -143,8 +145,10 @@ def test_batch():
|
|||||||
assert batch3.a.d.e[0] == 4.0
|
assert batch3.a.d.e[0] == 4.0
|
||||||
batch3.a.d[0] = Batch(f=5.0)
|
batch3.a.d[0] = Batch(f=5.0)
|
||||||
assert batch3.a.d.f[0] == 5.0
|
assert batch3.a.d.f[0] == 5.0
|
||||||
with pytest.raises(KeyError):
|
with pytest.raises(ValueError):
|
||||||
batch3.a.d[0] = Batch(f=5.0, g=0.0)
|
batch3.a.d[0] = Batch(f=5.0, g=0.0)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
batch3[0] = Batch(a={"c": 2, "e": 1})
|
||||||
# auto convert
|
# auto convert
|
||||||
batch4 = Batch(a=np.array(['a', 'b']))
|
batch4 = Batch(a=np.array(['a', 'b']))
|
||||||
assert batch4.a.dtype == np.object # auto convert to np.object
|
assert batch4.a.dtype == np.object # auto convert to np.object
|
||||||
@ -333,6 +337,12 @@ def test_batch_cat_and_stack():
|
|||||||
assert torch.allclose(test.b, ans.b)
|
assert torch.allclose(test.b, ans.b)
|
||||||
assert np.allclose(test.common.c, ans.common.c)
|
assert np.allclose(test.common.c, ans.common.c)
|
||||||
|
|
||||||
|
# test with illegal input format
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
Batch.cat([[Batch(a=1)], [Batch(a=1)]])
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
Batch.stack([[Batch(a=1)], [Batch(a=1)]])
|
||||||
|
|
||||||
# exceptions
|
# exceptions
|
||||||
assert Batch.cat([]).is_empty()
|
assert Batch.cat([]).is_empty()
|
||||||
assert Batch.stack([]).is_empty()
|
assert Batch.stack([]).is_empty()
|
||||||
|
@ -8,12 +8,6 @@ from collections.abc import Collection
|
|||||||
from typing import Any, List, Dict, Union, Iterator, Optional, Iterable, \
|
from typing import Any, List, Dict, Union, Iterator, Optional, Iterable, \
|
||||||
Sequence
|
Sequence
|
||||||
|
|
||||||
# Disable pickle warning related to torch, since it has been removed
|
|
||||||
# on torch master branch. See Pull Request #39003 for details:
|
|
||||||
# https://github.com/pytorch/pytorch/pull/39003
|
|
||||||
warnings.filterwarnings(
|
|
||||||
"ignore", message="pickle support for Storage will be removed in 1.5.")
|
|
||||||
|
|
||||||
|
|
||||||
def _is_batch_set(data: Any) -> bool:
|
def _is_batch_set(data: Any) -> bool:
|
||||||
# Batch set is a list/tuple of dict/Batch objects,
|
# Batch set is a list/tuple of dict/Batch objects,
|
||||||
@ -91,6 +85,9 @@ def _create_value(
|
|||||||
has_shape = isinstance(inst, (np.ndarray, torch.Tensor))
|
has_shape = isinstance(inst, (np.ndarray, torch.Tensor))
|
||||||
is_scalar = _is_scalar(inst)
|
is_scalar = _is_scalar(inst)
|
||||||
if not stack and is_scalar:
|
if not stack and is_scalar:
|
||||||
|
# _create_value(Batch(a={}, b=[1, 2, 3]), 10, False) will fail here
|
||||||
|
if isinstance(inst, Batch) and inst.is_empty(recurse=True):
|
||||||
|
return inst
|
||||||
# should never hit since it has already checked in Batch.cat_
|
# should never hit since it has already checked in Batch.cat_
|
||||||
# here we do not consider scalar types, following the behavior of numpy
|
# here we do not consider scalar types, following the behavior of numpy
|
||||||
# which does not support concatenation of zero-dimensional arrays
|
# which does not support concatenation of zero-dimensional arrays
|
||||||
@ -257,7 +254,7 @@ class Batch:
|
|||||||
raise ValueError("Batch does not supported tensor assignment. "
|
raise ValueError("Batch does not supported tensor assignment. "
|
||||||
"Use a compatible Batch or dict instead.")
|
"Use a compatible Batch or dict instead.")
|
||||||
if not set(value.keys()).issubset(self.__dict__.keys()):
|
if not set(value.keys()).issubset(self.__dict__.keys()):
|
||||||
raise KeyError(
|
raise ValueError(
|
||||||
"Creating keys is not supported by item assignment.")
|
"Creating keys is not supported by item assignment.")
|
||||||
for key, val in self.items():
|
for key, val in self.items():
|
||||||
try:
|
try:
|
||||||
@ -449,12 +446,21 @@ class Batch:
|
|||||||
"""Concatenate a list of (or one) Batch objects into current batch."""
|
"""Concatenate a list of (or one) Batch objects into current batch."""
|
||||||
if isinstance(batches, Batch):
|
if isinstance(batches, Batch):
|
||||||
batches = [batches]
|
batches = [batches]
|
||||||
if len(batches) == 0:
|
# check input format
|
||||||
return
|
batch_list = []
|
||||||
batches = [x if isinstance(x, Batch) else Batch(x) for x in batches]
|
for b in batches:
|
||||||
|
if isinstance(b, dict):
|
||||||
|
if len(b) > 0:
|
||||||
|
batch_list.append(Batch(b))
|
||||||
|
elif isinstance(b, Batch):
|
||||||
# x.is_empty() means that x is Batch() and should be ignored
|
# x.is_empty() means that x is Batch() and should be ignored
|
||||||
batches = [x for x in batches if not x.is_empty()]
|
if not b.is_empty():
|
||||||
|
batch_list.append(b)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Cannot concatenate {type(b)} in Batch.cat_")
|
||||||
|
if len(batch_list) == 0:
|
||||||
|
return
|
||||||
|
batches = batch_list
|
||||||
try:
|
try:
|
||||||
# x.is_empty(recurse=True) here means x is a nested empty batch
|
# x.is_empty(recurse=True) here means x is a nested empty batch
|
||||||
# like Batch(a=Batch), and we have to treat it as length zero and
|
# like Batch(a=Batch), and we have to treat it as length zero and
|
||||||
@ -496,9 +502,22 @@ class Batch:
|
|||||||
self, batches: Sequence[Union[dict, "Batch"]], axis: int = 0
|
self, batches: Sequence[Union[dict, "Batch"]], axis: int = 0
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Stack a list of Batch object into current batch."""
|
"""Stack a list of Batch object into current batch."""
|
||||||
if len(batches) == 0:
|
# check input format
|
||||||
|
batch_list = []
|
||||||
|
for b in batches:
|
||||||
|
if isinstance(b, dict):
|
||||||
|
if len(b) > 0:
|
||||||
|
batch_list.append(Batch(b))
|
||||||
|
elif isinstance(b, Batch):
|
||||||
|
# x.is_empty() means that x is Batch() and should be ignored
|
||||||
|
if not b.is_empty():
|
||||||
|
batch_list.append(b)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot concatenate {type(b)} in Batch.stack_")
|
||||||
|
if len(batch_list) == 0:
|
||||||
return
|
return
|
||||||
batches = [x if isinstance(x, Batch) else Batch(x) for x in batches]
|
batches = batch_list
|
||||||
if not self.is_empty():
|
if not self.is_empty():
|
||||||
batches = [self] + batches
|
batches = [self] + batches
|
||||||
# collect non-empty keys
|
# collect non-empty keys
|
||||||
|
@ -203,7 +203,7 @@ class ReplayBuffer:
|
|||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
value[self._index] = inst
|
value[self._index] = inst
|
||||||
except KeyError:
|
except ValueError:
|
||||||
for key in set(inst.keys()).difference(value.__dict__.keys()):
|
for key in set(inst.keys()).difference(value.__dict__.keys()):
|
||||||
value.__dict__[key] = _create_value(inst[key], self._maxsize)
|
value.__dict__[key] = _create_value(inst[key], self._maxsize)
|
||||||
value[self._index] = inst
|
value[self._index] = inst
|
||||||
|
Loading…
x
Reference in New Issue
Block a user