Improve to_torch/to_numpy converters (#147)

* Enable converting list/tuple back and forth from/to numpy/torch.

* Add fallbacks.

* Fix PEP8

* Update unit tests.

* Type annotation. Robust dtype check.

* List of object are converted individually, as a single tensor otherwise.

* Improve robustness of _to_array_with_correct_type

* Add unit tests.

* Do not catch exception at _to_array_with_correct_type level.

* Use _parse_value

* Fix PEP8

* Fix _parse_value list output type fallback.

* Catch torch exception.

* Do not convert torch tensor during fallback.

* Improve unit tests.

* Add unit tests.

* FIx missing import

* Remove support of numpy arrays of tensors for Batch value parser.

* Forbid numpy arrays of tensors.

* Fix PEP8.

* Fix comment.

* Reduce _parse_value branch number.

* Fix None value.

* Forward error message for debugging purpose.

* Fix _is_scalar.

* More specific try/catch blocks.

* Fix exception chaining.

* Fix PEP8.

* Fix _is_scalar.

* Fix missing corner case.

* Fix PEP8.

* Allow Batch empty key.

* Fix multi-dim array datatype check.

Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
This commit is contained in:
Alexis DUBURCQ 2020-07-21 10:47:56 +02:00 committed by GitHub
parent 8c32d99c65
commit 865ef6c693
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 103 additions and 37 deletions

View File

@ -3,8 +3,9 @@ import copy
import pickle
import pytest
import numpy as np
from itertools import starmap
from tianshou.data import Batch, to_torch
from tianshou.data import Batch, to_torch, to_numpy
def test_batch():
@ -28,8 +29,19 @@ def test_batch():
assert b.a == 3
with pytest.raises(AssertionError):
Batch({1: 2})
with pytest.raises(TypeError):
Batch(a=[np.zeros((2, 3)), np.zeros((3, 3))])
with pytest.raises(TypeError):
Batch(a=[np.zeros((3, 2)), np.zeros((3, 3))])
with pytest.raises(TypeError):
Batch(a=[torch.zeros((2, 3)), torch.zeros((3, 3))])
with pytest.raises(TypeError):
Batch(a=[torch.zeros((3, 3)), np.zeros((3, 3))])
with pytest.raises(TypeError):
Batch(a=[1, np.zeros((3, 3)), torch.zeros((3, 3))])
batch = Batch(a=[torch.ones(3), torch.ones(3)])
assert torch.allclose(batch.a, torch.ones(2, 3))
Batch(a=[])
batch = Batch(obs=[0], np=np.zeros([3, 4]))
assert batch.obs == batch["obs"]
batch.obs = [1]
@ -307,7 +319,7 @@ def test_batch_over_batch_to_torch():
assert batch.b.d.dtype == torch.float32
def test_utils_to_torch():
def test_utils_to_torch_numpy():
batch = Batch(
a=np.float64(1.0),
b=Batch(
@ -323,8 +335,37 @@ def test_utils_to_torch():
assert batch_torch_float.a.dtype == torch.float32
assert batch_torch_float.b.c.dtype == torch.float32
assert batch_torch_float.b.d.dtype == torch.float32
array_list = [float('nan'), 1.0]
assert to_torch(array_list).dtype == torch.float64
data_list = [float('nan'), 1]
data_list_torch = to_torch(data_list)
assert data_list_torch.dtype == torch.float64
data_list_2 = [np.random.rand(3, 3), np.random.rand(3, 3)]
data_list_2_torch = to_torch(data_list_2)
assert data_list_2_torch.shape == (2, 3, 3)
assert np.allclose(to_numpy(to_torch(data_list_2)), data_list_2)
data_list_3 = [np.zeros((3, 2)), np.zeros((3, 3))]
data_list_3_torch = to_torch(data_list_3)
assert isinstance(data_list_3_torch, list)
assert all(isinstance(e, torch.Tensor) for e in data_list_3_torch)
assert all(starmap(np.allclose,
zip(to_numpy(to_torch(data_list_3)), data_list_3)))
data_list_4 = [np.zeros((2, 3)), np.zeros((3, 3))]
data_list_4_torch = to_torch(data_list_4)
assert isinstance(data_list_4_torch, list)
assert all(isinstance(e, torch.Tensor) for e in data_list_4_torch)
assert all(starmap(np.allclose,
zip(to_numpy(to_torch(data_list_4)), data_list_4)))
data_list_5 = [np.zeros(2), np.zeros((3, 3))]
data_list_5_torch = to_torch(data_list_5)
assert isinstance(data_list_5_torch, list)
assert all(isinstance(e, torch.Tensor) for e in data_list_5_torch)
data_array = np.random.rand(3, 2, 2)
data_empty_tensor = to_torch(data_array[[]])
assert isinstance(data_empty_tensor, torch.Tensor)
assert data_empty_tensor.shape == (0, 2, 2)
data_empty_array = to_numpy(data_empty_tensor)
assert isinstance(data_empty_array, np.ndarray)
assert data_empty_array.shape == (0, 2, 2)
assert np.allclose(to_numpy(to_torch(data_array)), data_array)
def test_batch_pickle():
@ -432,7 +473,7 @@ if __name__ == '__main__':
test_batch()
test_batch_over_batch()
test_batch_over_batch_to_torch()
test_utils_to_torch()
test_utils_to_torch_numpy()
test_batch_pickle()
test_batch_from_to_numpy_without_copy()
test_batch_standard_compatibility()

View File

@ -4,6 +4,7 @@ import warnings
import numpy as np
from copy import deepcopy
from numbers import Number
from collections.abc import Collection
from typing import Any, List, Tuple, Union, Iterator, Optional
# Disable pickle warning related to torch, since it has been removed
@ -36,8 +37,11 @@ def _is_scalar(value: Any) -> bool:
# 3. python object rather than dict / Batch / tensor
# the check of dict / Batch is omitted because this only checks a value.
# a dict / Batch will eventually check their values
value = np.asanyarray(value)
return value.size == 1 and not value.shape
if isinstance(value, torch.Tensor):
return value.numel() == 1 and not value.shape
else:
value = np.asanyarray(value)
return value.size == 1 and not value.shape
def _is_number(value: Any) -> bool:
@ -53,16 +57,21 @@ def _is_number(value: Any) -> bool:
def _to_array_with_correct_type(v: Any) -> np.ndarray:
# convert the value to np.ndarray
# convert to np.object data type if neither bool nor number
# raises an exception if array's elements are tensors themself
v = np.asanyarray(v)
if not issubclass(v.dtype.type, (np.bool_, np.number)):
v = v.astype(np.object)
if v.dtype == np.object and not v.shape:
if v.dtype == np.object:
# scalar ndarray with np.object data type is very annoying
# a=np.array([np.array({}, dtype=object), np.array({}, dtype=object)])
# a is not array([{}, {}], dtype=object), and a[0]={} results in
# something very strange:
# array([{}, array({}, dtype=object)], dtype=object)
v = v.item(0)
if not v.shape:
v = v.item(0)
elif any(isinstance(e, (np.ndarray, torch.Tensor))
for e in v.reshape(-1)):
raise ValueError("Numpy arrays of tensors are not supported yet.")
return v
@ -113,25 +122,29 @@ def _assert_type_keys(keys):
def _parse_value(v: Any):
if isinstance(v, (list, tuple, np.ndarray)):
if not isinstance(v, np.ndarray) and \
all(isinstance(e, torch.Tensor) for e in v):
v = torch.stack(v)
return v
v_ = _to_array_with_correct_type(v)
if v_.dtype == np.object and _is_batch_set(v):
v = Batch(v) # list of dict / Batch
else:
# normal data list (main case)
# or actually a data list with objects
v = v_
elif isinstance(v, dict):
if isinstance(v, dict):
v = Batch(v)
elif isinstance(v, (Batch, torch.Tensor)):
pass
else:
# scalar case, convert to ndarray
v = _to_array_with_correct_type(v)
if not isinstance(v, np.ndarray) and isinstance(v, Collection) and \
len(v) > 0 and all(isinstance(e, torch.Tensor) for e in v):
try:
return torch.stack(v)
except RuntimeError as e:
raise TypeError("Batch does not support non-stackable iterable"
" of torch.Tensor as unique value yet.") from e
try:
v_ = _to_array_with_correct_type(v)
except ValueError as e:
raise TypeError("Batch does not support heterogeneous list/tuple"
" of tensors as unique value yet.") from e
if _is_batch_set(v):
v = Batch(v) # list of dict / Batch
else:
# None, scalar, normal data list (main case)
# or an actual list of objects
v = v_
return v

View File

@ -3,12 +3,12 @@ import numpy as np
from numbers import Number
from typing import Union, Optional
from tianshou.data import Batch
from tianshou.data.batch import _parse_value, Batch
def to_numpy(x: Union[
torch.Tensor, dict, Batch, np.ndarray]) -> Union[
dict, Batch, np.ndarray]:
Batch, dict, list, tuple, np.ndarray, torch.Tensor]) -> Union[
Batch, dict, list, tuple, np.ndarray, torch.Tensor]:
"""Return an object without torch.Tensor."""
if isinstance(x, torch.Tensor):
x = x.detach().cpu().numpy()
@ -17,13 +17,20 @@ def to_numpy(x: Union[
x[k] = to_numpy(v)
elif isinstance(x, Batch):
x.to_numpy()
elif isinstance(x, (list, tuple)):
try:
x = to_numpy(_parse_value(x))
except TypeError:
x = [to_numpy(e) for e in x]
else: # fallback
x = np.asanyarray(x)
return x
def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray],
def to_torch(x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor],
dtype: Optional[torch.dtype] = None,
device: Union[str, int, torch.device] = 'cpu'
) -> Union[dict, Batch, torch.Tensor]:
) -> Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor]:
"""Return an object without np.ndarray."""
if isinstance(x, torch.Tensor):
if dtype is not None:
@ -36,14 +43,19 @@ def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray],
x.to_torch(dtype, device)
elif isinstance(x, (np.number, np.bool_, Number)):
x = to_torch(np.asanyarray(x), dtype, device)
elif isinstance(x, list) and len(x) > 0 and \
all(isinstance(e, (np.number, np.bool_, Number)) for e in x):
x = to_torch(np.asanyarray(x), dtype, device)
elif isinstance(x, np.ndarray) and \
isinstance(x.item(0), (np.number, np.bool_, Number)):
x = torch.from_numpy(x).to(device)
if dtype is not None:
x = x.type(dtype)
elif isinstance(x, (list, tuple)):
try:
x = to_torch(_parse_value(x), dtype, device)
except TypeError:
x = [to_torch(e, dtype, device) for e in x]
else: # fallback
x = np.asanyarray(x)
if issubclass(x.dtype.type, (np.bool_, np.number)):
x = torch.from_numpy(x).to(device)
if dtype is not None:
x = x.type(dtype)
else:
raise TypeError(f"object {x} cannot be converted to torch.")
return x