Improve to_torch/to_numpy converters (#147)

* Enable converting list/tuple back and forth from/to numpy/torch.

* Add fallbacks.

* Fix PEP8

* Update unit tests.

* Type annotation. Robust dtype check.

* Lists of objects are converted individually; otherwise the list is converted as a single tensor.

* Improve robustness of _to_array_with_correct_type

* Add unit tests.

* Do not catch exception at _to_array_with_correct_type level.

* Use _parse_value

* Fix PEP8

* Fix _parse_value list output type fallback.

* Catch torch exception.

* Do not convert torch tensor during fallback.

* Improve unit tests.

* Add unit tests.

* Fix missing import

* Remove support of numpy arrays of tensors for Batch value parser.

* Forbid numpy arrays of tensors.

* Fix PEP8.

* Fix comment.

* Reduce _parse_value branch number.

* Fix None value.

* Forward error message for debugging purpose.

* Fix _is_scalar.

* More specific try/catch blocks.

* Fix exception chaining.

* Fix PEP8.

* Fix _is_scalar.

* Fix missing corner case.

* Fix PEP8.

* Allow Batch empty key.

* Fix multi-dim array datatype check.

Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
This commit is contained in:
Alexis DUBURCQ 2020-07-21 10:47:56 +02:00 committed by GitHub
parent 8c32d99c65
commit 865ef6c693
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 103 additions and 37 deletions

View File

@ -3,8 +3,9 @@ import copy
import pickle import pickle
import pytest import pytest
import numpy as np import numpy as np
from itertools import starmap
from tianshou.data import Batch, to_torch from tianshou.data import Batch, to_torch, to_numpy
def test_batch(): def test_batch():
@ -28,8 +29,19 @@ def test_batch():
assert b.a == 3 assert b.a == 3
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
Batch({1: 2}) Batch({1: 2})
with pytest.raises(TypeError):
Batch(a=[np.zeros((2, 3)), np.zeros((3, 3))])
with pytest.raises(TypeError):
Batch(a=[np.zeros((3, 2)), np.zeros((3, 3))])
with pytest.raises(TypeError):
Batch(a=[torch.zeros((2, 3)), torch.zeros((3, 3))])
with pytest.raises(TypeError):
Batch(a=[torch.zeros((3, 3)), np.zeros((3, 3))])
with pytest.raises(TypeError):
Batch(a=[1, np.zeros((3, 3)), torch.zeros((3, 3))])
batch = Batch(a=[torch.ones(3), torch.ones(3)]) batch = Batch(a=[torch.ones(3), torch.ones(3)])
assert torch.allclose(batch.a, torch.ones(2, 3)) assert torch.allclose(batch.a, torch.ones(2, 3))
Batch(a=[])
batch = Batch(obs=[0], np=np.zeros([3, 4])) batch = Batch(obs=[0], np=np.zeros([3, 4]))
assert batch.obs == batch["obs"] assert batch.obs == batch["obs"]
batch.obs = [1] batch.obs = [1]
@ -307,7 +319,7 @@ def test_batch_over_batch_to_torch():
assert batch.b.d.dtype == torch.float32 assert batch.b.d.dtype == torch.float32
def test_utils_to_torch(): def test_utils_to_torch_numpy():
batch = Batch( batch = Batch(
a=np.float64(1.0), a=np.float64(1.0),
b=Batch( b=Batch(
@ -323,8 +335,37 @@ def test_utils_to_torch():
assert batch_torch_float.a.dtype == torch.float32 assert batch_torch_float.a.dtype == torch.float32
assert batch_torch_float.b.c.dtype == torch.float32 assert batch_torch_float.b.c.dtype == torch.float32
assert batch_torch_float.b.d.dtype == torch.float32 assert batch_torch_float.b.d.dtype == torch.float32
array_list = [float('nan'), 1.0] data_list = [float('nan'), 1]
assert to_torch(array_list).dtype == torch.float64 data_list_torch = to_torch(data_list)
assert data_list_torch.dtype == torch.float64
data_list_2 = [np.random.rand(3, 3), np.random.rand(3, 3)]
data_list_2_torch = to_torch(data_list_2)
assert data_list_2_torch.shape == (2, 3, 3)
assert np.allclose(to_numpy(to_torch(data_list_2)), data_list_2)
data_list_3 = [np.zeros((3, 2)), np.zeros((3, 3))]
data_list_3_torch = to_torch(data_list_3)
assert isinstance(data_list_3_torch, list)
assert all(isinstance(e, torch.Tensor) for e in data_list_3_torch)
assert all(starmap(np.allclose,
zip(to_numpy(to_torch(data_list_3)), data_list_3)))
data_list_4 = [np.zeros((2, 3)), np.zeros((3, 3))]
data_list_4_torch = to_torch(data_list_4)
assert isinstance(data_list_4_torch, list)
assert all(isinstance(e, torch.Tensor) for e in data_list_4_torch)
assert all(starmap(np.allclose,
zip(to_numpy(to_torch(data_list_4)), data_list_4)))
data_list_5 = [np.zeros(2), np.zeros((3, 3))]
data_list_5_torch = to_torch(data_list_5)
assert isinstance(data_list_5_torch, list)
assert all(isinstance(e, torch.Tensor) for e in data_list_5_torch)
data_array = np.random.rand(3, 2, 2)
data_empty_tensor = to_torch(data_array[[]])
assert isinstance(data_empty_tensor, torch.Tensor)
assert data_empty_tensor.shape == (0, 2, 2)
data_empty_array = to_numpy(data_empty_tensor)
assert isinstance(data_empty_array, np.ndarray)
assert data_empty_array.shape == (0, 2, 2)
assert np.allclose(to_numpy(to_torch(data_array)), data_array)
def test_batch_pickle(): def test_batch_pickle():
@ -432,7 +473,7 @@ if __name__ == '__main__':
test_batch() test_batch()
test_batch_over_batch() test_batch_over_batch()
test_batch_over_batch_to_torch() test_batch_over_batch_to_torch()
test_utils_to_torch() test_utils_to_torch_numpy()
test_batch_pickle() test_batch_pickle()
test_batch_from_to_numpy_without_copy() test_batch_from_to_numpy_without_copy()
test_batch_standard_compatibility() test_batch_standard_compatibility()

View File

@ -4,6 +4,7 @@ import warnings
import numpy as np import numpy as np
from copy import deepcopy from copy import deepcopy
from numbers import Number from numbers import Number
from collections.abc import Collection
from typing import Any, List, Tuple, Union, Iterator, Optional from typing import Any, List, Tuple, Union, Iterator, Optional
# Disable pickle warning related to torch, since it has been removed # Disable pickle warning related to torch, since it has been removed
@ -36,8 +37,11 @@ def _is_scalar(value: Any) -> bool:
# 3. python object rather than dict / Batch / tensor # 3. python object rather than dict / Batch / tensor
# the check of dict / Batch is omitted because this only checks a value. # the check of dict / Batch is omitted because this only checks a value.
# a dict / Batch will eventually check their values # a dict / Batch will eventually check their values
value = np.asanyarray(value) if isinstance(value, torch.Tensor):
return value.size == 1 and not value.shape return value.numel() == 1 and not value.shape
else:
value = np.asanyarray(value)
return value.size == 1 and not value.shape
def _is_number(value: Any) -> bool: def _is_number(value: Any) -> bool:
@ -53,16 +57,21 @@ def _is_number(value: Any) -> bool:
def _to_array_with_correct_type(v: Any) -> np.ndarray: def _to_array_with_correct_type(v: Any) -> np.ndarray:
# convert the value to np.ndarray # convert the value to np.ndarray
# convert to np.object data type if neither bool nor number # convert to np.object data type if neither bool nor number
# raises an exception if array's elements are tensors themselves
v = np.asanyarray(v) v = np.asanyarray(v)
if not issubclass(v.dtype.type, (np.bool_, np.number)): if not issubclass(v.dtype.type, (np.bool_, np.number)):
v = v.astype(np.object) v = v.astype(np.object)
if v.dtype == np.object and not v.shape: if v.dtype == np.object:
# scalar ndarray with np.object data type is very annoying # scalar ndarray with np.object data type is very annoying
# a=np.array([np.array({}, dtype=object), np.array({}, dtype=object)]) # a=np.array([np.array({}, dtype=object), np.array({}, dtype=object)])
# a is not array([{}, {}], dtype=object), and a[0]={} results in # a is not array([{}, {}], dtype=object), and a[0]={} results in
# something very strange: # something very strange:
# array([{}, array({}, dtype=object)], dtype=object) # array([{}, array({}, dtype=object)], dtype=object)
v = v.item(0) if not v.shape:
v = v.item(0)
elif any(isinstance(e, (np.ndarray, torch.Tensor))
for e in v.reshape(-1)):
raise ValueError("Numpy arrays of tensors are not supported yet.")
return v return v
@ -113,25 +122,29 @@ def _assert_type_keys(keys):
def _parse_value(v: Any): def _parse_value(v: Any):
if isinstance(v, (list, tuple, np.ndarray)): if isinstance(v, dict):
if not isinstance(v, np.ndarray) and \
all(isinstance(e, torch.Tensor) for e in v):
v = torch.stack(v)
return v
v_ = _to_array_with_correct_type(v)
if v_.dtype == np.object and _is_batch_set(v):
v = Batch(v) # list of dict / Batch
else:
# normal data list (main case)
# or actually a data list with objects
v = v_
elif isinstance(v, dict):
v = Batch(v) v = Batch(v)
elif isinstance(v, (Batch, torch.Tensor)): elif isinstance(v, (Batch, torch.Tensor)):
pass pass
else: else:
# scalar case, convert to ndarray if not isinstance(v, np.ndarray) and isinstance(v, Collection) and \
v = _to_array_with_correct_type(v) len(v) > 0 and all(isinstance(e, torch.Tensor) for e in v):
try:
return torch.stack(v)
except RuntimeError as e:
raise TypeError("Batch does not support non-stackable iterable"
" of torch.Tensor as unique value yet.") from e
try:
v_ = _to_array_with_correct_type(v)
except ValueError as e:
raise TypeError("Batch does not support heterogeneous list/tuple"
" of tensors as unique value yet.") from e
if _is_batch_set(v):
v = Batch(v) # list of dict / Batch
else:
# None, scalar, normal data list (main case)
# or an actual list of objects
v = v_
return v return v

View File

@ -3,12 +3,12 @@ import numpy as np
from numbers import Number from numbers import Number
from typing import Union, Optional from typing import Union, Optional
from tianshou.data import Batch from tianshou.data.batch import _parse_value, Batch
def to_numpy(x: Union[ def to_numpy(x: Union[
torch.Tensor, dict, Batch, np.ndarray]) -> Union[ Batch, dict, list, tuple, np.ndarray, torch.Tensor]) -> Union[
dict, Batch, np.ndarray]: Batch, dict, list, tuple, np.ndarray, torch.Tensor]:
"""Return an object without torch.Tensor.""" """Return an object without torch.Tensor."""
if isinstance(x, torch.Tensor): if isinstance(x, torch.Tensor):
x = x.detach().cpu().numpy() x = x.detach().cpu().numpy()
@ -17,13 +17,20 @@ def to_numpy(x: Union[
x[k] = to_numpy(v) x[k] = to_numpy(v)
elif isinstance(x, Batch): elif isinstance(x, Batch):
x.to_numpy() x.to_numpy()
elif isinstance(x, (list, tuple)):
try:
x = to_numpy(_parse_value(x))
except TypeError:
x = [to_numpy(e) for e in x]
else: # fallback
x = np.asanyarray(x)
return x return x
def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray], def to_torch(x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor],
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
device: Union[str, int, torch.device] = 'cpu' device: Union[str, int, torch.device] = 'cpu'
) -> Union[dict, Batch, torch.Tensor]: ) -> Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor]:
"""Return an object without np.ndarray.""" """Return an object without np.ndarray."""
if isinstance(x, torch.Tensor): if isinstance(x, torch.Tensor):
if dtype is not None: if dtype is not None:
@ -36,14 +43,19 @@ def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray],
x.to_torch(dtype, device) x.to_torch(dtype, device)
elif isinstance(x, (np.number, np.bool_, Number)): elif isinstance(x, (np.number, np.bool_, Number)):
x = to_torch(np.asanyarray(x), dtype, device) x = to_torch(np.asanyarray(x), dtype, device)
elif isinstance(x, list) and len(x) > 0 and \ elif isinstance(x, (list, tuple)):
all(isinstance(e, (np.number, np.bool_, Number)) for e in x): try:
x = to_torch(np.asanyarray(x), dtype, device) x = to_torch(_parse_value(x), dtype, device)
elif isinstance(x, np.ndarray) and \ except TypeError:
isinstance(x.item(0), (np.number, np.bool_, Number)): x = [to_torch(e, dtype, device) for e in x]
x = torch.from_numpy(x).to(device) else: # fallback
if dtype is not None: x = np.asanyarray(x)
x = x.type(dtype) if issubclass(x.dtype.type, (np.bool_, np.number)):
x = torch.from_numpy(x).to(device)
if dtype is not None:
x = x.type(dtype)
else:
raise TypeError(f"object {x} cannot be converted to torch.")
return x return x